Skip to content

Commit 8ac7d52

Browse files
saitcakmakmeta-codesync[bot]
authored andcommitted
Change get_trace to return dict[int, float] and filter by MetricAvailability
Summary: `get_trace` previously returned `list[float]` with positional indexing, which forced callers to fabricate sequential trial indices. This change: 1. Changes the return type to `dict[int, float]` mapping trial_index to performance value, so callers get real trial indices. 2. Adds `MetricAvailability` filtering to exclude trials with incomplete metric data before pivoting. This prevents the `ValueError("Some metrics are not present for all trials and arms")` that `_pivot_data_with_feasibility` raises when a completed trial is missing any metric (e.g., partial fetches, fetch failures, metrics added mid-experiment). 3. Removes the carry-forward expansion loop for abandoned/failed trials. These trials are now simply excluded from the returned dict. Callers updated: - `UtilityProgressionAnalysis` now shows real trial indices on x-axis - `pareto_frontier.hypervolume_trace_plot` uses real trial indices - `benchmark.py` extracts values from dict Differential Revision: D99448053
1 parent c08c4a9 commit 8ac7d52

6 files changed

Lines changed: 108 additions & 96 deletions

File tree

ax/analysis/plotly/tests/test_utility_progression.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ def _assert_valid_utility_card(
3636
"""Assert that a card has valid structure for utility progression."""
3737
self.assertIsInstance(card, PlotlyAnalysisCard)
3838
self.assertEqual(card.name, "UtilityProgressionAnalysis")
39-
self.assertIn("trace_index", card.df.columns)
39+
self.assertIn("trial_index", card.df.columns)
4040
self.assertIn("utility", card.df.columns)
4141

4242
def test_utility_progression_soo(self) -> None:
@@ -211,7 +211,7 @@ def test_all_infeasible_points_raises_error(self) -> None:
211211
with (
212212
patch(
213213
"ax.analysis.plotly.utility_progression.get_trace",
214-
return_value=[math.inf, -math.inf, math.inf],
214+
return_value={0: math.inf, 1: -math.inf, 2: math.inf},
215215
),
216216
self.assertRaises(ExperimentNotReadyError) as cm,
217217
):

ax/analysis/plotly/utility_progression.py

Lines changed: 11 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -28,11 +28,9 @@
2828
_UTILITY_PROGRESSION_TITLE = "Utility Progression"
2929

3030
_TRACE_INDEX_EXPLANATION = (
31-
"The x-axis shows trace index, which counts completed or early-stopped trials "
32-
"sequentially (1, 2, 3, ...). This differs from trial index, which may have "
33-
"gaps if some trials failed or were abandoned. For example, if trials 0, 2, "
34-
"and 5 completed while trials 1, 3, and 4 failed, the trace indices would be "
35-
"1, 2, 3 corresponding to trial indices 0, 2, 5."
31+
"The x-axis shows trial index. Only completed or early-stopped trials with "
32+
"complete metric data are included, so there may be gaps if some trials "
33+
"failed, were abandoned, or have incomplete data."
3634
)
3735

3836
_CUMULATIVE_BEST_EXPLANATION = (
@@ -57,7 +55,8 @@ class UtilityProgressionAnalysis(Analysis):
5755
5856
The DataFrame computed will contain one row per completed trial and the
5957
following columns:
60-
- trace_index: Sequential index of completed/early-stopped trials (1, 2, 3, ...)
58+
- trial_index: The trial index of each completed/early-stopped trial
59+
that has complete metric avilability.
6160
- utility: The cumulative best utility value at that trial
6261
"""
6362

@@ -114,7 +113,7 @@ def compute(
114113
)
115114

116115
# Check if all points are infeasible (inf or -inf values)
117-
if all(np.isinf(value) for value in trace):
116+
if all(np.isinf(value) for value in trace.values()):
118117
raise ExperimentNotReadyError(
119118
"All trials in the utility trace are infeasible i.e., they violate "
120119
"outcome constraints, so there are no feasible points to plot. During "
@@ -125,12 +124,11 @@ def compute(
125124
"space, or (2) relaxing outcome constraints."
126125
)
127126

128-
# Create DataFrame with 1-based trace index for user-friendly display
129-
# (1st completed trial, 2nd completed trial, etc. instead of 0-indexed)
127+
# Create DataFrame with trial indices from the trace
130128
df = pd.DataFrame(
131129
{
132-
"trace_index": list(range(1, len(trace) + 1)),
133-
"utility": trace,
130+
"trial_index": list(trace.keys()),
131+
"utility": list(trace.values()),
134132
}
135133
)
136134

@@ -185,14 +183,14 @@ def compute(
185183
# Create the plot
186184
fig = px.line(
187185
data_frame=df,
188-
x="trace_index",
186+
x="trial_index",
189187
y="utility",
190188
markers=True,
191189
color_discrete_sequence=[AX_BLUE],
192190
)
193191

194192
# Update axis labels and format x-axis to show integers only
195-
fig.update_xaxes(title_text="Trace Index", dtick=1, rangemode="nonnegative")
193+
fig.update_xaxes(title_text="Trial Index", dtick=1, rangemode="nonnegative")
196194
fig.update_yaxes(title_text=y_label)
197195

198196
return create_plotly_analysis_card(

ax/benchmark/benchmark.py

Lines changed: 22 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -320,9 +320,10 @@ def _get_oracle_value_of_params(
320320
dummy_experiment = get_oracle_experiment_from_params(
321321
problem=problem, dict_of_dict_of_params={0: {"0_0": params}}
322322
)
323-
(inference_value,) = get_trace(
323+
trace = get_trace(
324324
experiment=dummy_experiment, optimization_config=problem.optimization_config
325325
)
326+
inference_value = next(iter(trace.values()))
326327
return inference_value
327328

328329

@@ -510,12 +511,27 @@ def get_benchmark_result_from_experiment_and_gs(
510511
dict_of_dict_of_params=dict_of_dict_of_params,
511512
trial_statuses=trial_statuses,
512513
)
513-
oracle_trace = np.array(
514-
get_trace(
515-
experiment=actual_params_oracle_dummy_experiment,
516-
optimization_config=problem.optimization_config,
517-
)
514+
oracle_trace_dict = get_trace(
515+
experiment=actual_params_oracle_dummy_experiment,
516+
optimization_config=problem.optimization_config,
517+
)
518+
# Expand trace dict to a positional array aligned with all trials,
519+
# carry-forwarding the last best value for trials without data (e.g.,
520+
# failed or abandoned trials preserved via trial_statuses).
521+
maximize = (
522+
isinstance(problem.optimization_config, MultiObjectiveOptimizationConfig)
523+
or problem.optimization_config.objective.is_scalarized_objective
524+
or not problem.optimization_config.objective.minimize
518525
)
526+
all_trial_indices = sorted(actual_params_oracle_dummy_experiment.trials.keys())
527+
last_best = -float("inf") if maximize else float("inf")
528+
oracle_trace_list: list[float] = []
529+
for idx in all_trial_indices:
530+
if idx in oracle_trace_dict:
531+
last_best = oracle_trace_dict[idx]
532+
oracle_trace_list.append(last_best)
533+
oracle_trace = np.array(oracle_trace_list)
534+
519535
is_feasible_trace = np.array(
520536
get_is_feasible_trace(
521537
experiment=actual_params_oracle_dummy_experiment,

ax/plot/pareto_frontier.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -61,8 +61,8 @@ def scatter_plot_with_hypervolume_trace_plotly(experiment: Experiment) -> go.Fig
6161

6262
df = pd.DataFrame(
6363
{
64-
"hypervolume": hypervolume_trace,
65-
"trial_index": [*range(len(hypervolume_trace))],
64+
"trial_index": list(hypervolume_trace.keys()),
65+
"hypervolume": list(hypervolume_trace.values()),
6666
}
6767
)
6868

ax/service/tests/test_best_point.py

Lines changed: 44 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,7 @@ def test_get_trace(self) -> None:
5757
exp = get_experiment_with_observations(
5858
observations=[[11], [10], [9], [15], [5]], minimize=True
5959
)
60-
self.assertEqual(get_trace(exp), [11, 10, 9, 9, 5])
60+
self.assertEqual(get_trace(exp), {0: 11, 1: 10, 2: 9, 3: 9, 4: 5})
6161

6262
# Same experiment with maximize via new optimization config.
6363
opt_conf = none_throws(exp.optimization_config).clone()
@@ -67,7 +67,7 @@ def test_get_trace(self) -> None:
6767
opt_conf.objective.metric_names[0]: opt_conf.objective.metric_names[0]
6868
},
6969
)
70-
self.assertEqual(get_trace(exp, opt_conf), [11, 11, 11, 15, 15])
70+
self.assertEqual(get_trace(exp, opt_conf), {0: 11, 1: 11, 2: 11, 3: 15, 4: 15})
7171

7272
with self.subTest("Single objective with constraints"):
7373
# The second metric is the constraint and needs to be >= 0
@@ -76,48 +76,52 @@ def test_get_trace(self) -> None:
7676
minimize=False,
7777
constrained=True,
7878
)
79-
self.assertEqual(get_trace(exp), [float("-inf"), 10, 10, 10, 11])
79+
self.assertEqual(
80+
get_trace(exp),
81+
{0: float("-inf"), 1: 10, 2: 10, 3: 10, 4: 11},
82+
)
8083

8184
exp = get_experiment_with_observations(
8285
observations=[[11, -1], [10, 1], [9, 1], [15, -1], [11, 1]],
8386
minimize=True,
8487
constrained=True,
8588
)
86-
self.assertEqual(get_trace(exp), [float("inf"), 10, 9, 9, 9])
89+
self.assertEqual(get_trace(exp), {0: float("inf"), 1: 10, 2: 9, 3: 9, 4: 9})
8790

8891
# Scalarized.
8992
exp = get_experiment_with_observations(
9093
observations=[[1, 1], [2, 2], [3, 3]],
9194
scalarized=True,
9295
)
93-
self.assertEqual(get_trace(exp), [2, 4, 6])
96+
self.assertEqual(get_trace(exp), {0: 2, 1: 4, 2: 6})
9497

9598
# Multi objective.
9699
exp = get_experiment_with_observations(
97100
observations=[[1, 1], [-1, 100], [1, 2], [3, 3], [2, 4], [2, 1]],
98101
)
99-
self.assertEqual(get_trace(exp), [1, 1, 2, 9, 11, 11])
102+
self.assertEqual(get_trace(exp), {0: 1, 1: 1, 2: 2, 3: 9, 4: 11, 5: 11})
100103

101104
# W/o ObjectiveThresholds (inferring ObjectiveThresholds from scaled nadir)
102105
assert_is_instance(
103106
exp.optimization_config, MultiObjectiveOptimizationConfig
104107
).objective_thresholds = []
105108
trace = get_trace(exp)
109+
trace_values = list(trace.values())
106110
# With inferred thresholds via scaled nadir, check trace properties:
107111
# - All values should be non-negative
108-
self.assertTrue(all(v >= 0.0 for v in trace))
112+
self.assertTrue(all(v >= 0.0 for v in trace_values))
109113
# - Trace should be non-decreasing (cumulative best)
110-
for i in range(1, len(trace)):
111-
self.assertGreaterEqual(trace[i], trace[i - 1])
114+
for i in range(1, len(trace_values)):
115+
self.assertGreaterEqual(trace_values[i], trace_values[i - 1])
112116
# - Final value should be positive (non-trivial HV)
113-
self.assertGreater(trace[-1], 0.0)
117+
self.assertGreater(trace_values[-1], 0.0)
114118

115119
# Multi-objective w/ constraints.
116120
exp = get_experiment_with_observations(
117121
observations=[[-1, 1, 1], [1, 2, 1], [3, 3, -1], [2, 4, 1], [2, 1, 1]],
118122
constrained=True,
119123
)
120-
self.assertEqual(get_trace(exp), [0, 2, 2, 8, 8])
124+
self.assertEqual(get_trace(exp), {0: 0, 1: 2, 2: 2, 3: 8, 4: 8})
121125

122126
# W/ relative constraints & status quo.
123127
exp.status_quo = Arm(parameters={"x": 0.5, "y": 0.5}, name="status_quo")
@@ -149,17 +153,17 @@ def test_get_trace(self) -> None:
149153
]
150154
status_quo_data = Data(df=pd.DataFrame.from_records(df_dict))
151155
exp.attach_data(data=status_quo_data)
152-
self.assertEqual(get_trace(exp), [0, 2, 2, 8, 8])
156+
self.assertEqual(get_trace(exp), {0: 0, 1: 2, 2: 2, 3: 8, 4: 8})
153157

154158
# W/ first objective being minimized.
155159
exp = get_experiment_with_observations(
156160
observations=[[1, 1], [-1, 2], [3, 3], [-2, 4], [2, 1]], minimize=True
157161
)
158-
self.assertEqual(get_trace(exp), [0, 2, 2, 8, 8])
162+
self.assertEqual(get_trace(exp), {0: 0, 1: 2, 2: 2, 3: 8, 4: 8})
159163

160164
# W/ empty data.
161165
exp = get_experiment_with_trial()
162-
self.assertEqual(get_trace(exp), [])
166+
self.assertEqual(get_trace(exp), {})
163167

164168
# test batch trial
165169
exp = get_experiment_with_batch_trial(with_status_quo=False)
@@ -191,7 +195,7 @@ def test_get_trace(self) -> None:
191195
]
192196
)
193197
exp.attach_data(Data(df=pd.DataFrame.from_records(df_dict)))
194-
self.assertEqual(get_trace(exp), [2.0])
198+
self.assertEqual(get_trace(exp), {0: 2.0})
195199
# test that there is performance metric in the trace for each
196200
# completed/early-stopped trial
197201
trial1 = assert_is_instance(trial, BatchTrial).clone_to(include_sq=False)
@@ -214,7 +218,7 @@ def test_get_trace(self) -> None:
214218
]
215219
)
216220
exp.attach_data(Data(df=pd.DataFrame.from_records(df_dict2)))
217-
self.assertEqual(get_trace(exp), [2.0, 2.0, 20.0])
221+
self.assertEqual(get_trace(exp), {0: 2.0, 2: 20.0})
218222

219223
def test_get_trace_with_non_completed_trials(self) -> None:
220224
with self.subTest("minimize with abandoned trial"):
@@ -224,12 +228,11 @@ def test_get_trace_with_non_completed_trials(self) -> None:
224228
# Mark trial 2 (value=9) as abandoned
225229
exp.trials[2].mark_abandoned(unsafe=True)
226230

227-
# Abandoned trial carries forward the last best value
231+
# Abandoned trial is excluded from trace
228232
trace = get_trace(exp)
229-
self.assertEqual(len(trace), 5)
230-
# Trial 0: 11, Trial 1: 10, Trial 2 (abandoned): carry forward 10
233+
# Trial 0: 11, Trial 1: 10, Trial 2 (abandoned): excluded
231234
# Trial 3: 10 (15 > 10), Trial 4: 5
232-
self.assertEqual(trace, [11, 10, 10, 10, 5])
235+
self.assertEqual(trace, {0: 11, 1: 10, 3: 10, 4: 5})
233236

234237
with self.subTest("maximize with abandoned trial"):
235238
exp = get_experiment_with_observations(
@@ -238,12 +241,11 @@ def test_get_trace_with_non_completed_trials(self) -> None:
238241
# Mark trial 1 (value=3) as abandoned
239242
exp.trials[1].mark_abandoned(unsafe=True)
240243

241-
# Abandoned trial carries forward the last best value
244+
# Abandoned trial is excluded from trace
242245
trace = get_trace(exp)
243-
self.assertEqual(len(trace), 5)
244-
# Trial 0: 1, Trial 1 (abandoned): carry forward 1,
246+
# Trial 0: 1, Trial 1 (abandoned): excluded,
245247
# Trial 2: 2, Trial 3: 5, Trial 4: 5
246-
self.assertEqual(trace, [1, 1, 2, 5, 5])
248+
self.assertEqual(trace, {0: 1, 2: 2, 3: 5, 4: 5})
247249

248250
with self.subTest("minimize with failed trial"):
249251
exp = get_experiment_with_observations(
@@ -252,12 +254,11 @@ def test_get_trace_with_non_completed_trials(self) -> None:
252254
# Mark trial 2 (value=9) as failed
253255
exp.trials[2].mark_failed(unsafe=True)
254256

255-
# Failed trial carries forward the last best value
257+
# Failed trial is excluded from trace
256258
trace = get_trace(exp)
257-
self.assertEqual(len(trace), 5)
258-
# Trial 0: 11, Trial 1: 10, Trial 2 (failed): carry forward 10
259+
# Trial 0: 11, Trial 1: 10, Trial 2 (failed): excluded
259260
# Trial 3: 10 (15 > 10), Trial 4: 5
260-
self.assertEqual(trace, [11, 10, 10, 10, 5])
261+
self.assertEqual(trace, {0: 11, 1: 10, 3: 10, 4: 5})
261262

262263
def test_get_trace_with_include_status_quo(self) -> None:
263264
with self.subTest("Multi-objective: status quo dominates in some trials"):
@@ -347,9 +348,11 @@ def test_get_trace_with_include_status_quo(self) -> None:
347348
# The last value MUST differ because status quo dominates
348349
# Without status quo, only poor arms contribute (low hypervolume)
349350
# With status quo, excellent values contribute (high hypervolume)
351+
last_without = list(trace_without_sq.values())[-1]
352+
last_with = list(trace_with_sq.values())[-1]
350353
self.assertGreater(
351-
trace_with_sq[-1],
352-
trace_without_sq[-1],
354+
last_with,
355+
last_without,
353356
f"Status quo dominates in trial 3, so trace with SQ should be higher. "
354357
f"Without SQ: {trace_without_sq}, With SQ: {trace_with_sq}",
355358
)
@@ -418,9 +421,11 @@ def test_get_trace_with_include_status_quo(self) -> None:
418421
# The last value MUST differ because status quo is best
419422
# Without status quo: best in trial 3 is 15.0, cumulative min is 9
420423
# With status quo: best in trial 3 is 5.0, cumulative min is 5
424+
last_without = list(trace_without_sq.values())[-1]
425+
last_with = list(trace_with_sq.values())[-1]
421426
self.assertLess(
422-
trace_with_sq[-1],
423-
trace_without_sq[-1],
427+
last_with,
428+
last_without,
424429
f"Status quo is best in trial 3, so trace with SQ should be "
425430
f"lower (minimize). Without SQ: {trace_without_sq}, "
426431
f"With SQ: {trace_with_sq}",
@@ -502,19 +507,20 @@ def _make_pref_opt_config(self, profile_name: str) -> PreferenceOptimizationConf
502507
preference_profile_name=profile_name,
503508
)
504509

505-
def _assert_valid_trace(self, trace: list[float], expected_len: int) -> None:
510+
def _assert_valid_trace(self, trace: dict[int, float], expected_len: int) -> None:
506511
"""Assert trace has expected length, contains floats, is non-decreasing and has
507512
more than one unique value."""
508513
self.assertEqual(len(trace), expected_len)
509-
for value in trace:
514+
trace_values = list(trace.values())
515+
for value in trace_values:
510516
self.assertIsInstance(value, float)
511-
for i in range(1, len(trace)):
517+
for i in range(1, len(trace_values)):
512518
self.assertGreaterEqual(
513-
trace[i],
514-
trace[i - 1],
519+
trace_values[i],
520+
trace_values[i - 1],
515521
msg=f"Trace not monotonically increasing at index {i}: {trace}",
516522
)
517-
unique_values = set(trace)
523+
unique_values = set(trace_values)
518524
self.assertGreater(
519525
len(unique_values),
520526
1,

0 commit comments

Comments
 (0)