Skip to content

Commit 4e345f4

Browse files
ltiao authored and meta-codesync[bot] committed
Refactor get_trace and extend get_opt_trace_by_steps to MOO/constrained (#4884)
Summary:
Pull Request resolved: #4884

We have a method `get_opt_trace_by_steps` that was used extensively during our Ax 1.0 benchmarking campaign. It duplicates the basic logic of `get_trace` but differs in that it operates along `(trial_index, MAP_KEY)` pairs and respects ordering by timestamp (i.e. chronological order). However, it is limited to single-objective unconstrained problems, and our current needs (multi-objective and/or constrained) have outgrown it.

We reconcile the two by extracting three core building blocks of `get_trace`:
1. `_pivot_data_with_feasibility`: Pivots data to wide format with feasibility information and metric completeness checks.
2. `_compute_trace_values`: Computes per-observation trace values (hypervolume for MOO, objective value for SOO), with cumulative best support.
3. `_aggregate_and_cumulate_trace`: Aggregates values by groups and computes the cumulative best across groups.

These are implemented in a more general way that respects arbitrary groupings and orderings. We then refactor `get_trace` (and its helpers `_prepare_data_for_trace` and `get_trace_by_arm_pull_from_data`) to use these building blocks, and leverage them in `get_opt_trace_by_steps` to extend its support to multi-objective and constrained problems.

Additionally:
- The timestamp-based sorting in `get_opt_trace_by_steps` is preserved, which is critical for correct cumulative hypervolume computation (without this, observations would be processed in `(trial_index, arm_name, MAP_KEY)` order instead of chronological order).
- Tests are updated to replace `NotImplementedError` checks with actual MOO and constrained test cases that verify correctness of the new functionality.

Reviewed By: dme65

Differential Revision: D79581270

fbshipit-source-id: 3ffa69c2d6fab6928de6b12a59e3cdec2325a191
1 parent 2c8dbe5 commit 4e345f4

4 files changed

Lines changed: 348 additions & 129 deletions

File tree

ax/benchmark/benchmark.py

Lines changed: 72 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -54,9 +54,13 @@
5454
from ax.generation_strategy.generation_strategy import GenerationStrategy
5555
from ax.orchestration.orchestrator import Orchestrator
5656
from ax.service.utils.best_point import (
57+
_aggregate_and_cumulate_trace,
58+
_compute_trace_values,
59+
_pivot_data_with_feasibility,
5760
_prepare_data_for_trace,
5861
derelativize_opt_config,
5962
get_trace,
63+
is_row_feasible,
6064
)
6165
from ax.service.utils.best_point_mixin import BestPointMixin
6266
from ax.service.utils.orchestrator_options import OrchestratorOptions, TrialType
@@ -796,64 +800,80 @@ def get_opt_trace_by_steps(experiment: Experiment) -> npt.NDArray:
796800
that is in terms of steps, with one element added each time a step
797801
completes.
798802
803+
Supports single-objective, multi-objective, and constrained problems.
804+
For multi-objective problems, the trace is in terms of hypervolume.
805+
799806
Args:
800807
experiment: An experiment produced by `benchmark_replication`; it must
801808
have `BenchmarkTrialMetadata` (as produced by `BenchmarkRunner`) for
802809
each trial, and its data must have a "step" column.
803810
"""
804811
optimization_config = none_throws(experiment.optimization_config)
812+
full_df = experiment.lookup_data().full_df
805813

806-
if optimization_config.is_moo_problem:
807-
raise NotImplementedError(
808-
"Cumulative epochs only supported for single objective problems."
809-
)
810-
if len(optimization_config.outcome_constraints) > 0:
811-
raise NotImplementedError(
812-
"Cumulative epochs not supported for problems with outcome constraints."
813-
)
814+
full_df["row_feasible"] = is_row_feasible(
815+
df=full_df,
816+
optimization_config=optimization_config,
817+
# For the sake of this function, we only care about feasible trials. The
818+
# distinction between infeasible and undetermined is not important.
819+
undetermined_value=False,
820+
)
814821

815-
objective_name = optimization_config.objective.metric.name
816-
data = experiment.lookup_data()
817-
full_df = data.full_df
822+
# Pivot to wide format with feasibility
823+
df_wide = _pivot_data_with_feasibility(
824+
df=full_df,
825+
index=["trial_index", "arm_name", MAP_KEY],
826+
optimization_config=optimization_config,
827+
)
818828

819-
# Has timestamps; needs to be merged with full_df because it contains
820-
# data on epochs that didn't actually run due to early stopping, and we need
821-
# to know which actually ran
822-
def _get_df(trial: Trial) -> pd.DataFrame:
829+
def _get_timestamps(experiment: Experiment) -> pd.Series:
823830
"""
824-
Get the (virtual) time each epoch finished at.
831+
Get the (virtual) time at which each training progression finished.
825832
"""
826-
metadata = trial.run_metadata["benchmark_metadata"]
827-
backend_simulator = none_throws(metadata.backend_simulator)
828-
# Data for the first metric, which is the only metric
829-
df = next(iter(metadata.dfs.values()))
830-
start_time = backend_simulator.get_sim_trial_by_index(
831-
trial.index
832-
).sim_start_time
833-
df["time"] = df["virtual runtime"] + start_time
834-
return df
835-
836-
with_timestamps = pd.concat(
837-
(
838-
_get_df(trial=assert_is_instance(trial, Trial))
839-
for trial in experiment.trials.values()
840-
),
841-
axis=0,
842-
ignore_index=True,
843-
)[["trial_index", MAP_KEY, "time"]]
844-
845-
df = (
846-
full_df.loc[
847-
full_df["metric_name"] == objective_name,
848-
["trial_index", "arm_name", "mean", MAP_KEY],
849-
]
850-
.merge(with_timestamps, how="left")
851-
.sort_values("time", ignore_index=True)
833+
frames = []
834+
for trial in experiment.trials.values():
835+
trial = assert_is_instance(trial, Trial)
836+
metadata = trial.run_metadata["benchmark_metadata"]
837+
backend_simulator = none_throws(metadata.backend_simulator)
838+
sim_trial = backend_simulator.get_sim_trial_by_index(
839+
trial_index=trial.index
840+
)
841+
start_time = sim_trial.sim_start_time
842+
# timestamps are identical across all metrics, so just use the first one
843+
frame = next(iter(metadata.dfs.values())).copy()
844+
frame["time"] = frame["virtual runtime"] + start_time
845+
frames.append(frame)
846+
df = pd.concat(frames, axis=0, ignore_index=True).set_index(
847+
["trial_index", "arm_name", MAP_KEY]
848+
)
849+
return df["time"]
850+
851+
# Compute timestamps and join with df_wide *before* cumulative computations.
852+
# This is critical because cumulative HV/objective calculations depend on
853+
# the temporal ordering of observations.
854+
timestamps = _get_timestamps(experiment=experiment)
855+
856+
# Merge timestamps and sort by time before cumulative computations
857+
df_wide = df_wide.join(
858+
timestamps, on=["trial_index", "arm_name", MAP_KEY], how="left"
859+
).sort_values(by="time", ascending=True, ignore_index=True)
860+
861+
# Compute per-evaluation (trial_index, MAP_KEY) cumulative values,
862+
# relying on the sort by "time" above so the cumulative best is chronological
863+
df_wide["value"], maximize = _compute_trace_values(
864+
df_wide=df_wide,
865+
optimization_config=optimization_config,
866+
use_cumulative_best=True,
852867
)
853-
return (
854-
df["mean"].cummin()
855-
if optimization_config.objective.minimize
856-
else df["mean"].cummax()
868+
# Get a value for each (trial_index, arm_name, MAP_KEY) tuple
869+
value_by_arm_pull = df_wide[["trial_index", "arm_name", MAP_KEY, "value"]]
870+
871+
# Aggregate by trial and step, then compute cumulative best
872+
return _aggregate_and_cumulate_trace(
873+
df=value_by_arm_pull,
874+
by=["trial_index", MAP_KEY],
875+
maximize=maximize,
876+
keep_order=True,
857877
).to_numpy()
858878

859879

@@ -872,15 +892,16 @@ def get_benchmark_result_with_cumulative_steps(
872892
opt_trace = get_opt_trace_by_steps(experiment=experiment)
873893
return replace(
874894
result,
875-
optimization_trace=opt_trace,
876-
cost_trace=np.arange(1, len(opt_trace) + 1, dtype=int),
895+
optimization_trace=opt_trace.tolist(),
896+
cost_trace=np.arange(1, len(opt_trace) + 1, dtype=int).tolist(),
877897
num_trials=list(range(1, len(opt_trace) + 1)),
878898
# Empty
879-
oracle_trace=np.full(len(opt_trace), np.nan),
880-
inference_trace=np.full(len(opt_trace), np.nan),
899+
oracle_trace=np.full_like(opt_trace, np.nan).tolist(),
900+
inference_trace=np.full_like(opt_trace, np.nan).tolist(),
901+
is_feasible_trace=None,
881902
score_trace=compute_score_trace(
882903
optimization_trace=opt_trace,
883904
baseline_value=baseline_value,
884905
optimal_value=optimal_value,
885-
),
906+
).tolist(),
886907
)

ax/benchmark/testing/benchmark_stubs.py

Lines changed: 47 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -313,15 +313,52 @@ def get_async_benchmark_problem(
313313
n_steps: int = 1,
314314
lower_is_better: bool = False,
315315
report_inference_value_as_trace: bool = False,
316+
num_objectives: int = 1,
317+
num_constraints: int = 0,
316318
) -> BenchmarkProblem:
319+
"""
320+
Create an early-stopping benchmark problem with MAP_KEY data.
321+
322+
Args:
323+
map_data: Whether to use map metrics (required for early stopping).
324+
step_runtime_fn: Optional runtime function for steps.
325+
n_steps: Number of steps per trial.
326+
lower_is_better: Whether lower values are better (for SOO).
327+
report_inference_value_as_trace: Whether to report inference trace.
328+
num_objectives: Number of objectives (1 for SOO, >1 for MOO).
329+
num_constraints: Number of outcome constraints to add.
330+
331+
Returns:
332+
A BenchmarkProblem suitable for early-stopping evaluation.
333+
"""
317334
search_space = get_discrete_search_space()
318-
test_function = IdentityTestFunction(n_steps=n_steps)
319-
optimization_config = get_soo_opt_config(
320-
outcome_names=["objective"],
321-
use_map_metric=map_data,
322-
observe_noise_sd=True,
323-
lower_is_better=lower_is_better,
324-
)
335+
336+
# Create outcome names for objectives and constraints
337+
objective_names = [f"objective_{i}" for i in range(num_objectives)]
338+
constraint_names = [f"constraint_{i}" for i in range(num_constraints)]
339+
outcome_names = [*objective_names, *constraint_names]
340+
341+
test_function = IdentityTestFunction(n_steps=n_steps, outcome_names=outcome_names)
342+
343+
if num_objectives == 1:
344+
# Single-objective: first outcome is objective, rest are constraints
345+
optimization_config = get_soo_opt_config(
346+
outcome_names=outcome_names,
347+
lower_is_better=lower_is_better,
348+
observe_noise_sd=True,
349+
use_map_metric=map_data,
350+
)
351+
else:
352+
# Multi-objective: pass all outcomes (objectives + constraints)
353+
# get_moo_opt_config will use the last num_constraints as constraints
354+
optimization_config = get_moo_opt_config(
355+
outcome_names=outcome_names,
356+
ref_point=[1.0] * num_objectives,
357+
num_constraints=num_constraints,
358+
lower_is_better=lower_is_better,
359+
observe_noise_sd=True,
360+
use_map_metric=map_data,
361+
)
325362

326363
return BenchmarkProblem(
327364
name="test",
@@ -331,6 +368,9 @@ def get_async_benchmark_problem(
331368
num_trials=4,
332369
baseline_value=19 if lower_is_better else 0,
333370
optimal_value=0 if lower_is_better else 19,
371+
worst_feasible_value=(19 if lower_is_better else 0)
372+
if num_constraints > 0
373+
else None,
334374
step_runtime_function=step_runtime_fn,
335375
report_inference_value_as_trace=report_inference_value_as_trace,
336376
)

ax/benchmark/tests/test_benchmark.py

Lines changed: 71 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1214,28 +1214,85 @@ def test_get_opt_trace_by_cumulative_epochs(self) -> None:
12141214
new_opt_trace = get_opt_trace_by_steps(experiment=experiment)
12151215
self.assertEqual(list(new_opt_trace), [0.0, 0.0, 1.0, 1.0, 2.0, 3.0])
12161216

1217-
method = get_sobol_benchmark_method()
1218-
with self.subTest("MOO"):
1219-
problem = get_multi_objective_benchmark_problem()
1220-
1217+
with self.subTest("Multi-objective"):
1218+
# Multi-objective problem with step data
1219+
problem = get_async_benchmark_problem(
1220+
map_data=True,
1221+
n_steps=5,
1222+
num_objectives=2,
1223+
# Ensure we don't have two finishing at the same time, for
1224+
# determinism
1225+
step_runtime_fn=lambda params: params["x0"] * (1 - 0.01 * params["x0"]),
1226+
)
12211227
experiment = self.run_optimization_with_orchestrator(
12221228
problem=problem, method=method, seed=0
12231229
)
1224-
with self.assertRaisesRegex(
1225-
NotImplementedError, "only supported for single objective"
1226-
):
1227-
get_opt_trace_by_steps(experiment=experiment)
1230+
new_opt_trace = get_opt_trace_by_steps(experiment=experiment)
1231+
self.assertListEqual(
1232+
new_opt_trace.tolist(),
1233+
[
1234+
0.0,
1235+
0.0,
1236+
0.0,
1237+
0.0,
1238+
0.0,
1239+
0.0,
1240+
0.0,
1241+
1.0,
1242+
1.0,
1243+
1.0,
1244+
1.0,
1245+
1.0,
1246+
1.0,
1247+
4.0,
1248+
4.0,
1249+
4.0,
1250+
4.0,
1251+
4.0,
1252+
4.0,
1253+
4.0,
1254+
],
1255+
)
12281256

12291257
with self.subTest("Constrained"):
1230-
problem = get_benchmark_problem("constrained_gramacy_observed_noise")
1258+
# Constrained problem with step data.
1259+
problem = get_async_benchmark_problem(
1260+
map_data=True,
1261+
n_steps=5,
1262+
num_constraints=1,
1263+
# Ensure we don't have two finishing at the same time, for
1264+
# determinism
1265+
step_runtime_fn=lambda params: params["x0"] * (1 - 0.01 * params["x0"]),
1266+
)
12311267
experiment = self.run_optimization_with_orchestrator(
12321268
problem=problem, method=method, seed=0
12331269
)
1234-
with self.assertRaisesRegex(
1235-
NotImplementedError,
1236-
"not supported for problems with outcome constraints",
1237-
):
1238-
get_opt_trace_by_steps(experiment=experiment)
1270+
new_opt_trace = get_opt_trace_by_steps(experiment=experiment)
1271+
self.assertListEqual(
1272+
new_opt_trace.tolist(),
1273+
[
1274+
0.0,
1275+
0.0,
1276+
0.0,
1277+
0.0,
1278+
0.0,
1279+
1.0,
1280+
1.0,
1281+
2.0,
1282+
2.0,
1283+
2.0,
1284+
2.0,
1285+
2.0,
1286+
2.0,
1287+
3.0,
1288+
3.0,
1289+
3.0,
1290+
3.0,
1291+
3.0,
1292+
3.0,
1293+
3.0,
1294+
],
1295+
)
12391296

12401297
def test_get_benchmark_result_with_cumulative_steps(self) -> None:
12411298
"""See test_get_opt_trace_by_cumulative_epochs for more info."""

0 commit comments

Comments
 (0)