Skip to content

Commit 8370e5f

Browse files
mpolson64meta-codesync[bot]
authored andcommitted
Refactor Objective, OutcomeConstraint to be expression-based; remove Metrics from OptimizationConfig
Summary: This diff is a major refactor. In this diff we remove Metric instances from both Objective and OutcomeConstraint (and transitively, also off OptimizationConfig) and instead add them directly to the Experiment. In their place, we use the "expression" syntax and Sympy parsing from the Ax API directly on these classes, allowing users to specify `Objective(expression="ne1 + 2*ne2")` or `OutcomeConstraint(expression="qps >= 7000")` directly. This has a number of benefits: * Dependency spread is contained. OptimizationConfig is now able to be loaded without loading the any specific Metric implementation * Major separation of concerns: OptimizationConfig is now solely concerned with the goals of the optimization and not also implicitly concerned with fetching logic * Bringing Expressions directly onto Objective and OutcomeConstraint allows us to deprecated ScalarizedObjective, MultiObjective, ScalarizedOutcomeConstraint, and ObjectiveThreshold * Massive simplification for eventual upstreaming of MultiTypeExperiment functionality into base Experiment This is also a necessary step in our storage redesign in order to make metrics "live" only in one place. See this doc for details: https://docs.google.com/document/d/1I8FSQJ05_WHXFHBtNaL3dTAdAbvKfADwMRlIdq6BEV8/edit?tab=t.0#heading=h.s4r4rlrnva07 The refactor touches several core modules, please focus review on these: * ax/core/objective.py * ax/core/outcome_constraint.py * ax/core/optimization_config.py * ax/core/experiment.py These changes mandate many many updates to callsites which is ballooning the size of the diff. Changes are summarized here. # Core changes: * Objective, OutcomeConstraint classes now takes an expression string (e.g., "accuracy", "-loss", "2*acc + recall", "acc, -loss") * Old-style constructors still work, but they discard the Metric after initialization * Objective.metric property no longer exists. If a user wants to get the specific metric they retrieve its name via get_names and pass that string into Experiment.get_metric(name: str) -> Metric * isinstance checks for MultiObective, ScalarizedObjective, etc are replaced with is_multi_objective and is_scalarized_objective properties * Experiment.tracking_metrics is a property not an attr, now experiment simply has a collection of Metrics called "metrics" and tracking_metrics is a property which returns the metrics not on the optimization config * **Parsing is directly lifted and shifted from ax/api** * **Both JSON and SQL encoders/decoders carefully reconstruct the pre-refactor structure, allowing existing storage to be unaffected** # Other changes ## **ax/adapter/** - **adapter_utils.py** - Functions updated to accept experiment metrics for lookup. - All extraction functions use metric names. - **base.py, torch.py, transforms/** - Updated to thread experiment metrics through extraction functions. - Metric access patterns updated. ## **ax/analysis/** - Updated all metric access patterns to use metric names, not Metric objects. ## **ax/api/** - **client.py** - Simplified metric overwriting logic: metrics only live on Experiment. - **utils/instantiation/from_string.py** - Simplified parsing: constructs Objective directly from expression string. ## **ax/benchmark/** - Updated objective construction to use expression-based interface. ## **ax/core/** - **experiment.py** - Experiment now maintains a flat list of metrics (`_metrics: dict[str, Metric]`). - No distinction between tracking and optimization metrics. - Objective and constraints reference metrics by name, not object. - **objectives.py** - Unified Objective class: now takes an expression string, supporting single, scalarized, and multi-objective optimization. - Deprecated MultiObjective and ScalarizedObjective classes; their functionality is merged into Objective. - All metric access is via metric names/weights parsed from the expression, not direct Metric objects. - **optimization_config.py** - Unified OptimizationConfig class: handles both single and multi-objective cases. - Removed MultiObjectiveOptimizationConfig. - All config logic uses expression-based Objective and OutcomeConstraint. - **outcome_constraint.py** - Unified OutcomeConstraint class: now takes an expression string, supporting scalarized constraints. - Removed ObjectiveThreshold and ScalarizedOutcomeConstraint classes. - All validation and access is via metric names/weights parsed from the expression. - **trial.py** - Updated to reference metrics via `objective.metric_names[0]` instead of `objective.metric.name`. - **utils.py** - Updated metric access patterns to use metric names. ## **ax/early_stopping/** - Updated to reference metric signatures via experiment metrics, not via Objective. ## **ax/plot/** - Updated metric access patterns to use metric names. ## **ax/service/** - **ax_client.py** - Updated to reference metrics via `objective.metric_names[0]`. - Simplified metric overwriting logic. - **utils/best_point.py** - Updated metric access patterns to use metric names. ## **ax/storage/json_store/** - Encoder: serializes objectives as expression strings. - Decoder: converts legacy formats (with embedded Metric dicts) to expression-based objects. ## **ax/storage/sqa_store/** - Encoder: stores expression string in JSON column. - Decoder: reconstructs expression strings from MetricIntent-based SQA rows. ## **Tests** - All test constructors rewritten to use expression strings. - Assertions updated for new property names. - Added tests for expression parsing, multi-objective, scalarized, and minimize detection. Differential Revision: D93520819
1 parent 1054802 commit 8370e5f

128 files changed

Lines changed: 4623 additions & 2837 deletions

File tree

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

ax/adapter/adapter_utils.py

Lines changed: 57 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
)
2424
from ax.core.arm import Arm
2525
from ax.core.experiment import Experiment
26-
from ax.core.objective import MultiObjective, Objective, ScalarizedObjective
26+
from ax.core.objective import Objective
2727
from ax.core.observation import Observation, ObservationData, ObservationFeatures
2828
from ax.core.optimization_config import (
2929
MultiObjectiveOptimizationConfig,
@@ -204,6 +204,7 @@ def extract_objective_thresholds(
204204
objective_thresholds: TRefPoint,
205205
objective: Objective,
206206
outcomes: list[str],
207+
experiment: Experiment,
207208
) -> npt.NDArray | None:
208209
"""Extracts objective thresholds' values, in the order of `outcomes`.
209210
@@ -219,6 +220,7 @@ def extract_objective_thresholds(
219220
objective_thresholds: Objective thresholds to extract values from.
220221
objective: The corresponding Objective, for validation purposes.
221222
outcomes: n-length list of names of metrics.
223+
experiment: The experiment, used to map metric names to signatures.
222224
223225
Returns:
224226
(n,) array of thresholds
@@ -228,15 +230,19 @@ def extract_objective_thresholds(
228230

229231
objective_threshold_dict = {}
230232
for ot in objective_thresholds:
233+
ot_signature = experiment.get_metric(ot.metric_names[0]).signature
231234
if ot.relative:
232235
raise ValueError(
233-
f"Objective {ot.metric.signature} has a relative threshold that "
236+
f"Objective {ot_signature} has a relative threshold that "
234237
f"is not supported here."
235238
)
236-
objective_threshold_dict[ot.metric.signature] = ot.bound
239+
objective_threshold_dict[ot_signature] = ot.bound
237240

238241
# Check that all thresholds correspond to a metric.
239-
if set(objective_threshold_dict.keys()).difference(set(objective.metric_names)):
242+
obj_metric_signatures = [
243+
experiment.get_metric(name).signature for name in objective.metric_names
244+
]
245+
if set(objective_threshold_dict.keys()).difference(set(obj_metric_signatures)):
240246
raise ValueError(
241247
"Some objective thresholds do not have corresponding metrics. "
242248
f"Got {objective_thresholds=} and {objective=}."
@@ -250,7 +256,9 @@ def extract_objective_thresholds(
250256
return obj_t
251257

252258

253-
def extract_objective_weights(objective: Objective, outcomes: list[str]) -> npt.NDArray:
259+
def extract_objective_weights(
260+
objective: Objective, outcomes: list[str], experiment: Experiment
261+
) -> npt.NDArray:
254262
"""Extract a weights for objectives.
255263
256264
Weights are for a maximization problem.
@@ -266,29 +274,24 @@ def extract_objective_weights(objective: Objective, outcomes: list[str]) -> npt.
266274
267275
Args:
268276
objective: Objective to extract weights from.
269-
outcomes: n-length list of names of metrics.
277+
outcomes: n-length list of metric signatures.
278+
experiment: The experiment, used to map metric names to signatures.
270279
271280
Returns:
272281
n-length array of weights.
273282
274283
"""
275284
objective_weights = np.zeros(len(outcomes))
276-
if isinstance(objective, ScalarizedObjective):
277-
s = -1.0 if objective.minimize else 1.0
278-
for obj_metric, obj_weight in objective.metric_weights:
279-
objective_weights[outcomes.index(obj_metric.signature)] = obj_weight * s
280-
elif isinstance(objective, MultiObjective):
281-
for obj in objective.objectives:
282-
s = -1.0 if obj.minimize else 1.0
283-
objective_weights[outcomes.index(obj.metric.signature)] = s
284-
else:
285-
s = -1.0 if objective.minimize else 1.0
286-
objective_weights[outcomes.index(objective.metric.signature)] = s
285+
# metric_weights returns sign-encoded (name, weight) tuples for all
286+
# objective types (single, scalarized, multi).
287+
for obj_metric_name, obj_weight in objective.metric_weights:
288+
sig = experiment.get_metric(obj_metric_name).signature
289+
objective_weights[outcomes.index(sig)] = obj_weight
287290
return objective_weights
288291

289292

290293
def extract_objective_weight_matrix(
291-
objective: Objective, outcomes: list[str]
294+
objective: Objective, outcomes: list[str], experiment: Experiment
292295
) -> npt.NDArray:
293296
"""Extract a 2D weight matrix for objectives.
294297
@@ -302,23 +305,31 @@ def extract_objective_weight_matrix(
302305
303306
Args:
304307
objective: Objective to extract weights from.
305-
outcomes: n-length list of names of metrics.
308+
outcomes: n-length list of signatures of metrics.
306309
307310
Returns:
308311
``(n_objectives, n)`` array of weights.
309312
"""
310-
if isinstance(objective, MultiObjective):
313+
if objective.is_multi_objective:
311314
rows: list[npt.NDArray] = []
312-
for obj in objective.objectives:
313-
rows.append(extract_objective_weights(obj, outcomes))
315+
for name, weight in objective.metric_weights:
316+
rows.append(
317+
extract_objective_weights(
318+
objective=Objective(expression=f"{weight} * {name}"),
319+
outcomes=outcomes,
320+
experiment=experiment,
321+
)
322+
)
314323
return np.stack(rows, axis=0)
315324
else:
316325
# Single row – covers Objective and ScalarizedObjective
317-
return extract_objective_weights(objective, outcomes).reshape(1, -1)
326+
return extract_objective_weights(objective, outcomes, experiment).reshape(1, -1)
318327

319328

320329
def extract_outcome_constraints(
321-
outcome_constraints: list[OutcomeConstraint], outcomes: list[str]
330+
outcome_constraints: list[OutcomeConstraint],
331+
outcomes: list[str],
332+
experiment: Experiment,
322333
) -> TBounds:
323334
if len(outcome_constraints) == 0:
324335
return None
@@ -328,11 +339,11 @@ def extract_outcome_constraints(
328339
for i, c in enumerate(outcome_constraints):
329340
s = 1 if c.op == ComparisonOp.LEQ else -1
330341
if isinstance(c, ScalarizedOutcomeConstraint):
331-
for c_metric, c_weight in c.metric_weights:
332-
j = outcomes.index(c_metric.signature)
342+
for c_metric_name, c_weight in c.metric_weights:
343+
j = outcomes.index(experiment.get_metric(c_metric_name).signature)
333344
A[i, j] = s * c_weight
334345
else:
335-
j = outcomes.index(c.metric.signature)
346+
j = outcomes.index(experiment.get_metric(c.metric_names[0]).signature)
336347
A[i, j] = s
337348
b[i, 0] = s * c.bound
338349
return (A, b)
@@ -643,16 +654,20 @@ def get_pareto_frontier_and_configs(
643654
)
644655
# Extract weights, constraints, and objective_thresholds
645656
objective_weights = extract_objective_weight_matrix(
646-
objective=optimization_config.objective, outcomes=adapter.outcomes
657+
objective=optimization_config.objective,
658+
outcomes=adapter.outcomes,
659+
experiment=adapter._experiment,
647660
)
648661
outcome_constraints = extract_outcome_constraints(
649662
outcome_constraints=optimization_config.outcome_constraints,
650663
outcomes=adapter.outcomes,
664+
experiment=adapter._experiment,
651665
)
652666
obj_t = extract_objective_thresholds(
653667
objective_thresholds=optimization_config.objective_thresholds,
654668
objective=optimization_config.objective,
655669
outcomes=adapter.outcomes,
670+
experiment=adapter._experiment,
656671
)
657672
if obj_t is not None:
658673
obj_t = array_to_tensor(obj_t)
@@ -1111,6 +1126,7 @@ def observation_features_to_array(
11111126
def feasible_hypervolume(
11121127
optimization_config: MultiObjectiveOptimizationConfig,
11131128
values: dict[str, npt.NDArray],
1129+
experiment: Experiment,
11141130
) -> npt.NDArray:
11151131
"""Compute the feasible hypervolume each iteration.
11161132
@@ -1119,34 +1135,35 @@ def feasible_hypervolume(
11191135
values: Dictionary from metric name to array of value at each
11201136
iteration (each array is `n`-dim). If optimization config contains
11211137
outcome constraints, values for them must be present in `values`.
1138+
experiment: The experiment, used to map metric names to signatures.
11221139
11231140
Returns: Array of feasible hypervolumes.
11241141
"""
11251142
# Get objective at each iteration
11261143
obj_threshold_dict = {
1127-
ot.metric.signature: ot.bound for ot in optimization_config.objective_thresholds
1144+
experiment.get_metric(ot.metric_names[0]).signature: ot.bound
1145+
for ot in optimization_config.objective_thresholds
11281146
}
1129-
f_vals = np.hstack(
1130-
[
1131-
values[m.signature].reshape(-1, 1)
1132-
for m in optimization_config.objective.metrics
1133-
]
1134-
)
1135-
obj_thresholds = np.array(
1136-
[obj_threshold_dict[m.signature] for m in optimization_config.objective.metrics]
1137-
)
1147+
obj_metric_names = optimization_config.objective.metric_names
1148+
obj_metrics = [experiment.get_metric(name) for name in obj_metric_names]
1149+
f_vals = np.hstack([values[m.signature].reshape(-1, 1) for m in obj_metrics])
1150+
obj_thresholds = np.array([obj_threshold_dict[m.signature] for m in obj_metrics])
11381151
# Set infeasible points to be the objective threshold
11391152
for oc in optimization_config.outcome_constraints:
11401153
if oc.relative:
11411154
raise ValueError(
11421155
"Benchmark aggregation does not support relative constraints"
11431156
)
1144-
g = values[oc.metric.signature]
1157+
oc_sig = experiment.get_metric(oc.metric_names[0]).signature
1158+
g = values[oc_sig]
11451159
feas = g <= oc.bound if oc.op == ComparisonOp.LEQ else g >= oc.bound
11461160
f_vals[~feas] = obj_thresholds
11471161

1162+
# Derive objective directions from the objective's metric_weights.
1163+
# Positive weight = maximize, negative weight = minimize.
1164+
obj_weight_dict = dict(optimization_config.objective.metric_weights)
11481165
obj_weights = np.array(
1149-
[-1 if m.lower_is_better else 1 for m in optimization_config.objective.metrics]
1166+
[1 if obj_weight_dict[name] > 0 else -1 for name in obj_metric_names]
11501167
)
11511168
obj_thresholds = obj_thresholds * obj_weights
11521169
f_vals = f_vals * obj_weights

ax/adapter/base.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -201,7 +201,7 @@ def __init__(
201201
"Optimization config is required when "
202202
"`fit_tracking_metrics` is False."
203203
)
204-
self.outcomes = sorted(self._optimization_config.metrics.keys())
204+
self.outcomes = sorted(self._optimization_config.metric_names)
205205

206206
# Set training data (in the raw / untransformed space). This also omits
207207
# out-of-design and abandoned observations depending on the corresponding flags.
@@ -466,10 +466,14 @@ def _set_status_quo(self, experiment: Experiment) -> None:
466466
)
467467
return
468468

469-
if has_map_metrics(optimization_config=self._optimization_config):
469+
if has_map_metrics(
470+
metrics=experiment.get_metrics(
471+
metric_names=[*self._optimization_config.metric_names]
472+
)
473+
):
470474
self._status_quo = _combine_multiple_status_quo_observations(
471475
status_quo_observations=status_quo_observations,
472-
metrics=set(none_throws(self._optimization_config).metrics),
476+
metrics=none_throws(self._optimization_config).metric_names,
473477
)
474478
else:
475479
logger.warning(
@@ -689,7 +693,7 @@ def _get_transformed_gen_args(
689693
# Check that the optimization config has the same metrics as
690694
# the original one. Otherwise, we may attempt to optimize over
691695
# metrics that do not have a fitted model.
692-
outcomes = set(optimization_config.metrics.keys())
696+
outcomes = optimization_config.metric_names
693697
if not outcomes.issubset(self.outcomes):
694698
raise UnsupportedError(
695699
"When fit_tracking_metrics is False, the optimization config "

ax/adapter/cross_validation.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
from ax.adapter.data_utils import ExperimentData
2222
from ax.adapter.observation_utils import unwrap_observation_data
2323
from ax.adapter.torch import TorchAdapter
24+
from ax.core.experiment import Experiment
2425
from ax.core.observation import Observation, ObservationData, ObservationFeatures
2526
from ax.core.optimization_config import OptimizationConfig
2627
from ax.exceptions.core import UnsupportedError
@@ -573,6 +574,7 @@ def assess_model_fit(
573574
def has_good_opt_config_model_fit(
574575
optimization_config: OptimizationConfig,
575576
assess_model_fit_result: AssessModelFitResult,
577+
experiment: Experiment,
576578
) -> bool:
577579
"""Assess model fit for given diagnostics results across the optimization
578580
config metrics
@@ -584,7 +586,8 @@ def has_good_opt_config_model_fit(
584586
585587
Args:
586588
optimization_config: Objective/Outcome constraint metrics to assess
587-
diagnostics: Output of compute_diagnostics
589+
assess_model_fit_result: Output of assess_model_fit
590+
experiment: The experiment, used to map metric names to signatures.
588591
589592
Returns:
590593
Two dictionaries, one for good metrics, one for bad metrics, each
@@ -594,8 +597,9 @@ def has_good_opt_config_model_fit(
594597
# Bad fit criteria: Any objective metrics are poorly fit
595598
# TODO[]: Incl. outcome constraints in assessment
596599
has_good_opt_config_fit = all(
597-
(m.signature in assess_model_fit_result.good_fit_metrics_to_fisher_score)
598-
for m in optimization_config.objective.metrics
600+
experiment.get_metric(name).signature
601+
in assess_model_fit_result.good_fit_metrics_to_fisher_score
602+
for name in optimization_config.objective.metric_names
599603
)
600604
return has_good_opt_config_fit
601605

ax/adapter/discrete.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -163,13 +163,18 @@ def _gen(
163163
objective_weights = None
164164
outcome_constraints = None
165165
else:
166-
validate_transformed_optimization_config(optimization_config, self.outcomes)
166+
validate_transformed_optimization_config(
167+
optimization_config, self.outcomes, experiment=self._experiment
168+
)
167169
objective_weights = extract_objective_weights(
168-
objective=optimization_config.objective, outcomes=self.outcomes
170+
objective=optimization_config.objective,
171+
outcomes=self.outcomes,
172+
experiment=self._experiment,
169173
)
170174
outcome_constraints = extract_outcome_constraints(
171175
outcome_constraints=optimization_config.outcome_constraints,
172176
outcomes=self.outcomes,
177+
experiment=self._experiment,
173178
)
174179

175180
# Get fixed features

0 commit comments

Comments
 (0)