Skip to content

Commit 06623f3

Browse files
mpolson64facebook-github-bot
authored andcommitted
Improve ergonomics around metric_name --> signature conversion in Adapter and Transform (#5045)
Summary: D93520819 removed Metric instances from the OptimizationConfig which made it more verbose to get any individual Metric's signature, which we use for bookkeeping throughout the adapter stack. Previously, functions like extract_objective_weights, extract_outcome_constraints, and extract_objective_thresholds accepted an Experiment object solely to call experiment.get_metric(name).signature. This diff replaces that experiment parameter with a lightweight dict[str, str] (metric_name_to_signature) across the adapter and transform layers. Ultimately this reduces coupling to Experiment in the modeling layer, making these functions easier to test and reuse in contexts where a full Experiment object isn't available. Reviewed By: Balandat Differential Revision: D96855090
1 parent d016302 commit 06623f3

13 files changed

Lines changed: 143 additions & 107 deletions

ax/adapter/adapter_utils.py

Lines changed: 34 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -9,10 +9,10 @@
99
from __future__ import annotations
1010

1111
import warnings
12-
from collections.abc import Callable, Iterable, Mapping, MutableMapping, Sequence
12+
from collections.abc import Iterable, Mapping, MutableMapping, Sequence
1313
from copy import deepcopy
1414
from logging import Logger
15-
from typing import Any, cast, SupportsFloat, TYPE_CHECKING
15+
from typing import Any, Callable, cast, SupportsFloat, TYPE_CHECKING
1616

1717
import numpy as np
1818
import numpy.typing as npt
@@ -206,7 +206,7 @@ def extract_objective_thresholds(
206206
objective_thresholds: TRefPoint,
207207
objective: Objective,
208208
outcomes: list[str],
209-
experiment: Experiment,
209+
metric_name_to_signature: Mapping[str, str],
210210
) -> npt.NDArray | None:
211211
"""Extracts objective thresholds' values, in the order of `outcomes`.
212212
@@ -222,7 +222,7 @@ def extract_objective_thresholds(
222222
objective_thresholds: Objective thresholds to extract values from.
223223
objective: The corresponding Objective, for validation purposes.
224224
outcomes: n-length list of names of metrics.
225-
experiment: The experiment, used to map metric names to signatures.
225+
metric_name_to_signature: Mapping from metric names to signatures.
226226
227227
Returns:
228228
(n,) array of thresholds
@@ -232,7 +232,7 @@ def extract_objective_thresholds(
232232

233233
objective_threshold_dict = {}
234234
for ot in objective_thresholds:
235-
ot_signature = experiment.get_metric(ot.metric_names[0]).signature
235+
ot_signature = metric_name_to_signature[ot.metric_names[0]]
236236
if ot.relative:
237237
raise ValueError(
238238
f"Objective {ot_signature} has a relative threshold that "
@@ -242,7 +242,7 @@ def extract_objective_thresholds(
242242

243243
# Check that all thresholds correspond to a metric.
244244
obj_metric_signatures = [
245-
experiment.get_metric(name).signature for name in objective.metric_names
245+
metric_name_to_signature[name] for name in objective.metric_names
246246
]
247247
if set(objective_threshold_dict.keys()).difference(set(obj_metric_signatures)):
248248
raise ValueError(
@@ -259,7 +259,9 @@ def extract_objective_thresholds(
259259

260260

261261
def extract_objective_weights(
262-
objective: Objective, outcomes: list[str], experiment: Experiment
262+
objective: Objective,
263+
outcomes: list[str],
264+
metric_name_to_signature: Mapping[str, str],
263265
) -> npt.NDArray:
264266
"""Extract a weights for objectives.
265267
@@ -277,7 +279,7 @@ def extract_objective_weights(
277279
Args:
278280
objective: Objective to extract weights from.
279281
outcomes: n-length list of metric signatures.
280-
experiment: The experiment, used to map metric names to signatures.
282+
metric_name_to_signature: Mapping from metric names to signatures.
281283
282284
Returns:
283285
n-length array of weights.
@@ -287,13 +289,15 @@ def extract_objective_weights(
287289
# metric_weights returns sign-encoded (name, weight) tuples for all
288290
# objective types (single, scalarized, multi).
289291
for obj_metric_name, obj_weight in objective.metric_weights:
290-
sig = experiment.get_metric(obj_metric_name).signature
292+
sig = metric_name_to_signature[obj_metric_name]
291293
objective_weights[outcomes.index(sig)] = obj_weight
292294
return objective_weights
293295

294296

295297
def extract_objective_weight_matrix(
296-
objective: Objective, outcomes: list[str], experiment: Experiment
298+
objective: Objective,
299+
outcomes: list[str],
300+
metric_name_to_signature: Mapping[str, str],
297301
) -> npt.NDArray:
298302
"""Extract a 2D weight matrix for objectives.
299303
@@ -308,6 +312,7 @@ def extract_objective_weight_matrix(
308312
Args:
309313
objective: Objective to extract weights from.
310314
outcomes: n-length list of signatures of metrics.
315+
metric_name_to_signature: Mapping from metric names to signatures.
311316
312317
Returns:
313318
``(n_objectives, n)`` array of weights.
@@ -319,19 +324,23 @@ def extract_objective_weight_matrix(
319324
extract_objective_weights(
320325
objective=Objective(expression=f"{weight} * {name}"),
321326
outcomes=outcomes,
322-
experiment=experiment,
327+
metric_name_to_signature=metric_name_to_signature,
323328
)
324329
)
325330
return np.stack(rows, axis=0)
326331
else:
327332
# Single row – covers Objective and ScalarizedObjective
328-
return extract_objective_weights(objective, outcomes, experiment).reshape(1, -1)
333+
return extract_objective_weights(
334+
objective=objective,
335+
outcomes=outcomes,
336+
metric_name_to_signature=metric_name_to_signature,
337+
).reshape(1, -1)
329338

330339

331340
def extract_outcome_constraints(
332341
outcome_constraints: list[OutcomeConstraint],
333342
outcomes: list[str],
334-
experiment: Experiment,
343+
metric_name_to_signature: Mapping[str, str],
335344
) -> TBounds:
336345
if len(outcome_constraints) == 0:
337346
return None
@@ -342,10 +351,10 @@ def extract_outcome_constraints(
342351
s = 1 if c.op == ComparisonOp.LEQ else -1
343352
if isinstance(c, ScalarizedOutcomeConstraint):
344353
for c_metric_name, c_weight in c.metric_weights:
345-
j = outcomes.index(experiment.get_metric(c_metric_name).signature)
354+
j = outcomes.index(metric_name_to_signature[c_metric_name])
346355
A[i, j] = s * c_weight
347356
else:
348-
j = outcomes.index(experiment.get_metric(c.metric_names[0]).signature)
357+
j = outcomes.index(metric_name_to_signature[c.metric_names[0]])
349358
A[i, j] = s
350359
b[i, 0] = s * c.bound
351360
return (A, b)
@@ -658,18 +667,18 @@ def get_pareto_frontier_and_configs(
658667
objective_weights = extract_objective_weight_matrix(
659668
objective=optimization_config.objective,
660669
outcomes=adapter.outcomes,
661-
experiment=adapter._experiment,
670+
metric_name_to_signature=adapter.metric_name_to_signature,
662671
)
663672
outcome_constraints = extract_outcome_constraints(
664673
outcome_constraints=optimization_config.outcome_constraints,
665674
outcomes=adapter.outcomes,
666-
experiment=adapter._experiment,
675+
metric_name_to_signature=adapter.metric_name_to_signature,
667676
)
668677
obj_t = extract_objective_thresholds(
669678
objective_thresholds=optimization_config.objective_thresholds,
670679
objective=optimization_config.objective,
671680
outcomes=adapter.outcomes,
672-
experiment=adapter._experiment,
681+
metric_name_to_signature=adapter.metric_name_to_signature,
673682
)
674683
if obj_t is not None:
675684
obj_t = array_to_tensor(obj_t)
@@ -1128,7 +1137,7 @@ def observation_features_to_array(
11281137
def feasible_hypervolume(
11291138
optimization_config: MultiObjectiveOptimizationConfig,
11301139
values: dict[str, npt.NDArray],
1131-
experiment: Experiment,
1140+
metric_name_to_signature: Mapping[str, str],
11321141
) -> npt.NDArray:
11331142
"""Compute the feasible hypervolume each iteration.
11341143
@@ -1137,26 +1146,26 @@ def feasible_hypervolume(
11371146
values: Dictionary from metric name to array of value at each
11381147
iteration (each array is `n`-dim). If optimization config contains
11391148
outcome constraints, values for them must be present in `values`.
1140-
experiment: The experiment, used to map metric names to signatures.
1149+
metric_name_to_signature: Mapping from metric names to signatures.
11411150
11421151
Returns: Array of feasible hypervolumes.
11431152
"""
11441153
# Get objective at each iteration
11451154
obj_threshold_dict = {
1146-
experiment.get_metric(ot.metric_names[0]).signature: ot.bound
1155+
metric_name_to_signature[ot.metric_names[0]]: ot.bound
11471156
for ot in optimization_config.objective_thresholds
11481157
}
11491158
obj_metric_names = optimization_config.objective.metric_names
1150-
obj_metrics = [experiment.get_metric(name) for name in obj_metric_names]
1151-
f_vals = np.hstack([values[m.signature].reshape(-1, 1) for m in obj_metrics])
1152-
obj_thresholds = np.array([obj_threshold_dict[m.signature] for m in obj_metrics])
1159+
obj_metric_sigs = [metric_name_to_signature[name] for name in obj_metric_names]
1160+
f_vals = np.hstack([values[sig].reshape(-1, 1) for sig in obj_metric_sigs])
1161+
obj_thresholds = np.array([obj_threshold_dict[sig] for sig in obj_metric_sigs])
11531162
# Set infeasible points to be the objective threshold
11541163
for oc in optimization_config.outcome_constraints:
11551164
if oc.relative:
11561165
raise ValueError(
11571166
"Benchmark aggregation does not support relative constraints"
11581167
)
1159-
oc_sig = experiment.get_metric(oc.metric_names[0]).signature
1168+
oc_sig = metric_name_to_signature[oc.metric_names[0]]
11601169
g = values[oc_sig]
11611170
feas = g <= oc.bound if oc.op == ComparisonOp.LEQ else g >= oc.bound
11621171
f_vals[~feas] = obj_thresholds

ax/adapter/base.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -190,6 +190,9 @@ def __init__(
190190
)
191191
self._experiment_properties: dict[str, Any] = experiment._properties
192192
self._experiment: Experiment = experiment
193+
self._metric_name_to_signature: dict[str, str] = {
194+
name: metric.signature for name, metric in self._experiment.metrics.items()
195+
}
193196

194197
if self._optimization_config is None:
195198
self._optimization_config = experiment.optimization_config
@@ -525,6 +528,11 @@ def metric_signatures(self) -> set[str]:
525528
"""Metric signatures present in training data."""
526529
return self._metric_signatures
527530

531+
@property
532+
def metric_name_to_signature(self) -> dict[str, str]:
533+
"""Mapping from metric names to their signatures."""
534+
return self._metric_name_to_signature
535+
528536
@property
529537
def model_space(self) -> SearchSpace:
530538
"""SearchSpace used to fit model."""

ax/adapter/cross_validation.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
from __future__ import annotations
1010

1111
from collections import defaultdict
12-
from collections.abc import Callable, Iterable
12+
from collections.abc import Callable, Iterable, Mapping
1313
from logging import Logger
1414
from typing import cast, NamedTuple
1515
from warnings import warn
@@ -21,7 +21,6 @@
2121
from ax.adapter.data_utils import ExperimentData
2222
from ax.adapter.observation_utils import unwrap_observation_data
2323
from ax.adapter.torch import TorchAdapter
24-
from ax.core.experiment import Experiment
2524
from ax.core.observation import Observation, ObservationData, ObservationFeatures
2625
from ax.core.optimization_config import OptimizationConfig
2726
from ax.exceptions.core import UnsupportedError
@@ -574,7 +573,7 @@ def assess_model_fit(
574573
def has_good_opt_config_model_fit(
575574
optimization_config: OptimizationConfig,
576575
assess_model_fit_result: AssessModelFitResult,
577-
experiment: Experiment,
576+
metric_name_to_signature: Mapping[str, str],
578577
) -> bool:
579578
"""Assess model fit for given diagnostics results across the optimization
580579
config metrics
@@ -587,7 +586,7 @@ def has_good_opt_config_model_fit(
587586
Args:
588587
optimization_config: Objective/Outcome constraint metrics to assess
589588
assess_model_fit_result: Output of assess_model_fit
590-
experiment: The experiment, used to map metric names to signatures.
589+
metric_name_to_signature: Mapping from metric names to signatures.
591590
592591
Returns:
593592
Two dictionaries, one for good metrics, one for bad metrics, each
@@ -597,7 +596,7 @@ def has_good_opt_config_model_fit(
597596
# Bad fit criteria: Any objective metrics are poorly fit
598597
# TODO[]: Incl. outcome constraints in assessment
599598
has_good_opt_config_fit = all(
600-
experiment.get_metric(name).signature
599+
metric_name_to_signature[name]
601600
in assess_model_fit_result.good_fit_metrics_to_fisher_score
602601
for name in optimization_config.objective.metric_names
603602
)

ax/adapter/discrete.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -164,17 +164,19 @@ def _gen(
164164
outcome_constraints = None
165165
else:
166166
validate_transformed_optimization_config(
167-
optimization_config, self.outcomes, experiment=self._experiment
167+
optimization_config,
168+
self.outcomes,
169+
metric_name_to_signature=self.metric_name_to_signature,
168170
)
169171
objective_weights = extract_objective_weights(
170172
objective=optimization_config.objective,
171173
outcomes=self.outcomes,
172-
experiment=self._experiment,
174+
metric_name_to_signature=self.metric_name_to_signature,
173175
)
174176
outcome_constraints = extract_outcome_constraints(
175177
outcome_constraints=optimization_config.outcome_constraints,
176178
outcomes=self.outcomes,
177-
experiment=self._experiment,
179+
metric_name_to_signature=self.metric_name_to_signature,
178180
)
179181

180182
# Get fixed features

ax/adapter/tests/test_adapter_utils.py

Lines changed: 25 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,6 @@
4343
from ax.utils.common.hash_utils import compute_lilo_input_hash
4444
from ax.utils.common.testutils import TestCase
4545
from ax.utils.testing.core_stubs import (
46-
get_branin_experiment,
4746
get_experiment_with_observations,
4847
get_hierarchical_search_space,
4948
get_search_space_for_range_values,
@@ -79,11 +78,8 @@ def test_feasible_hypervolume(self) -> None:
7978
),
8079
],
8180
)
82-
experiment = Experiment(
83-
search_space=SearchSpace(parameters=[]),
84-
optimization_config=optimization_config,
85-
tracking_metrics=[mc],
86-
)
81+
# For plain Metric objects, signature == name.
82+
metric_name_to_signature = {m.name: m.name for m in [ma, mb, mc]}
8783
feas_hv = feasible_hypervolume(
8884
optimization_config,
8985
values={
@@ -112,7 +108,7 @@ def test_feasible_hypervolume(self) -> None:
112108
]
113109
),
114110
},
115-
experiment=experiment,
111+
metric_name_to_signature=metric_name_to_signature,
116112
)
117113
self.assertEqual(list(feas_hv), [0.0, 0.0, 1.0, 1.0])
118114

@@ -544,24 +540,34 @@ def test_can_map_to_binary(self) -> None:
544540
def test_extract_objective_weight_matrix(self) -> None:
545541
m1, m2, m3 = Metric(name="m1"), Metric(name="m2"), Metric(name="m3")
546542
outcomes = ["m1", "m2", "m3"]
547-
experiment = get_branin_experiment()
548-
experiment.add_metric(m1)
549-
experiment.add_metric(m2)
550-
experiment.add_metric(m3)
543+
# For plain Metric objects, signature == name.
544+
metric_name_to_signature = {name: name for name in outcomes}
551545

552546
# Single Objective: one row, nonzero only in matching column.
553547
obj = Objective(metric=m1, minimize=False)
554-
result = extract_objective_weight_matrix(obj, outcomes, experiment)
548+
result = extract_objective_weight_matrix(
549+
objective=obj,
550+
outcomes=outcomes,
551+
metric_name_to_signature=metric_name_to_signature,
552+
)
555553
np.testing.assert_array_equal(result, [[1.0, 0.0, 0.0]])
556554

557555
# Minimization flips the sign.
558556
obj_min = Objective(metric=m2, minimize=True)
559-
result = extract_objective_weight_matrix(obj_min, outcomes, experiment)
557+
result = extract_objective_weight_matrix(
558+
objective=obj_min,
559+
outcomes=outcomes,
560+
metric_name_to_signature=metric_name_to_signature,
561+
)
560562
np.testing.assert_array_equal(result, [[0.0, -1.0, 0.0]])
561563

562564
# ScalarizedObjective: single row with multiple nonzero entries.
563565
scal = ScalarizedObjective(metrics=[m1, m3], weights=[0.3, 0.7], minimize=False)
564-
result = extract_objective_weight_matrix(scal, outcomes, experiment)
566+
result = extract_objective_weight_matrix(
567+
objective=scal,
568+
outcomes=outcomes,
569+
metric_name_to_signature=metric_name_to_signature,
570+
)
565571
np.testing.assert_array_almost_equal(result, [[0.3, 0.0, 0.7]])
566572

567573
# MultiObjective: one row per sub-objective.
@@ -571,7 +577,11 @@ def test_extract_objective_weight_matrix(self) -> None:
571577
Objective(metric=m3, minimize=True),
572578
]
573579
)
574-
result = extract_objective_weight_matrix(multi, outcomes, experiment)
580+
result = extract_objective_weight_matrix(
581+
objective=multi,
582+
outcomes=outcomes,
583+
metric_name_to_signature=metric_name_to_signature,
584+
)
575585
np.testing.assert_array_equal(result, [[1.0, 0.0, 0.0], [0.0, 0.0, -1.0]])
576586

577587
def test_get_fresh_pairwise_trial_indices(self) -> None:

0 commit comments

Comments
 (0)