Skip to content

Commit c9196f2

Browse files
saitcakmakfacebook-github-bot
authored andcommitted
Change get_trace to return dict[int, float] and filter by MetricAvailability (facebook#5140)
Summary: Goal: Eliminate assumptions around COMPLETED trials always having all metrics available. This is part of the code that would've raised errors if it encountered missing metrics in COMPLETED trials. Changing the return type here seemed like a better move than using carry-forward logic, and fits better with the logic that skips status quo as well (which may lead to confusing indexing when skipping status quo). How does this new return type fit with `BatchTrial`: `get_trace` takes the best arm out of each batch, so it returns (and did so before this change) one result per trial. Nothing changed. Changes (claude summary): `get_trace` previously returned `list[float]` with positional indexing, which forced callers to fabricate sequential trial indices. This change: 1. Changes the return type to `dict[int, float]` mapping trial_index to performance value, so callers get real trial indices. 2. Adds `MetricAvailability` filtering to exclude trials with incomplete metric data before pivoting. This prevents the `ValueError("Some metrics are not present for all trials and arms")` that `_pivot_data_with_feasibility` raises when a completed trial is missing any metric (e.g., partial fetches, fetch failures, metrics added mid-experiment). 3. Removes the carry-forward expansion loop for abandoned/failed trials. These trials are now simply excluded from the returned dict. Callers updated: - `UtilityProgressionAnalysis` now shows real trial indices on x-axis - `pareto_frontier.hypervolume_trace_plot` uses real trial indices - `benchmark.py` extracts values from dict Differential Revision: D99448053
1 parent c08c4a9 commit c9196f2

6 files changed

Lines changed: 108 additions & 96 deletions

File tree

ax/analysis/plotly/tests/test_utility_progression.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ def _assert_valid_utility_card(
3636
"""Assert that a card has valid structure for utility progression."""
3737
self.assertIsInstance(card, PlotlyAnalysisCard)
3838
self.assertEqual(card.name, "UtilityProgressionAnalysis")
39-
self.assertIn("trace_index", card.df.columns)
39+
self.assertIn("trial_index", card.df.columns)
4040
self.assertIn("utility", card.df.columns)
4141

4242
def test_utility_progression_soo(self) -> None:
@@ -211,7 +211,7 @@ def test_all_infeasible_points_raises_error(self) -> None:
211211
with (
212212
patch(
213213
"ax.analysis.plotly.utility_progression.get_trace",
214-
return_value=[math.inf, -math.inf, math.inf],
214+
return_value={0: math.inf, 1: -math.inf, 2: math.inf},
215215
),
216216
self.assertRaises(ExperimentNotReadyError) as cm,
217217
):

ax/analysis/plotly/utility_progression.py

Lines changed: 11 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -28,11 +28,9 @@
2828
_UTILITY_PROGRESSION_TITLE = "Utility Progression"
2929

3030
_TRACE_INDEX_EXPLANATION = (
31-
"The x-axis shows trace index, which counts completed or early-stopped trials "
32-
"sequentially (1, 2, 3, ...). This differs from trial index, which may have "
33-
"gaps if some trials failed or were abandoned. For example, if trials 0, 2, "
34-
"and 5 completed while trials 1, 3, and 4 failed, the trace indices would be "
35-
"1, 2, 3 corresponding to trial indices 0, 2, 5."
31+
"The x-axis shows trial index. Only completed or early-stopped trials with "
32+
"complete metric data are included, so there may be gaps if some trials "
33+
"failed, were abandoned, or have incomplete data."
3634
)
3735

3836
_CUMULATIVE_BEST_EXPLANATION = (
@@ -57,7 +55,8 @@ class UtilityProgressionAnalysis(Analysis):
5755
5856
The DataFrame computed will contain one row per completed trial and the
5957
following columns:
60-
- trace_index: Sequential index of completed/early-stopped trials (1, 2, 3, ...)
58+
- trial_index: The trial index of each completed/early-stopped trial
59+
that has complete metric avilability.
6160
- utility: The cumulative best utility value at that trial
6261
"""
6362

@@ -114,7 +113,7 @@ def compute(
114113
)
115114

116115
# Check if all points are infeasible (inf or -inf values)
117-
if all(np.isinf(value) for value in trace):
116+
if all(np.isinf(value) for value in trace.values()):
118117
raise ExperimentNotReadyError(
119118
"All trials in the utility trace are infeasible i.e., they violate "
120119
"outcome constraints, so there are no feasible points to plot. During "
@@ -125,12 +124,11 @@ def compute(
125124
"space, or (2) relaxing outcome constraints."
126125
)
127126

128-
# Create DataFrame with 1-based trace index for user-friendly display
129-
# (1st completed trial, 2nd completed trial, etc. instead of 0-indexed)
127+
# Create DataFrame with trial indices from the trace
130128
df = pd.DataFrame(
131129
{
132-
"trace_index": list(range(1, len(trace) + 1)),
133-
"utility": trace,
130+
"trial_index": list(trace.keys()),
131+
"utility": list(trace.values()),
134132
}
135133
)
136134

@@ -185,14 +183,14 @@ def compute(
185183
# Create the plot
186184
fig = px.line(
187185
data_frame=df,
188-
x="trace_index",
186+
x="trial_index",
189187
y="utility",
190188
markers=True,
191189
color_discrete_sequence=[AX_BLUE],
192190
)
193191

194192
# Update axis labels and format x-axis to show integers only
195-
fig.update_xaxes(title_text="Trace Index", dtick=1, rangemode="nonnegative")
193+
fig.update_xaxes(title_text="Trial Index", dtick=1, rangemode="nonnegative")
196194
fig.update_yaxes(title_text=y_label)
197195

198196
return create_plotly_analysis_card(

ax/benchmark/benchmark.py

Lines changed: 22 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -320,9 +320,10 @@ def _get_oracle_value_of_params(
320320
dummy_experiment = get_oracle_experiment_from_params(
321321
problem=problem, dict_of_dict_of_params={0: {"0_0": params}}
322322
)
323-
(inference_value,) = get_trace(
323+
trace = get_trace(
324324
experiment=dummy_experiment, optimization_config=problem.optimization_config
325325
)
326+
inference_value = next(iter(trace.values()))
326327
return inference_value
327328

328329

@@ -510,12 +511,27 @@ def get_benchmark_result_from_experiment_and_gs(
510511
dict_of_dict_of_params=dict_of_dict_of_params,
511512
trial_statuses=trial_statuses,
512513
)
513-
oracle_trace = np.array(
514-
get_trace(
515-
experiment=actual_params_oracle_dummy_experiment,
516-
optimization_config=problem.optimization_config,
517-
)
514+
oracle_trace_dict = get_trace(
515+
experiment=actual_params_oracle_dummy_experiment,
516+
optimization_config=problem.optimization_config,
517+
)
518+
# Expand trace dict to a positional array aligned with all trials,
519+
# carry-forwarding the last best value for trials without data (e.g.,
520+
# failed or abandoned trials preserved via trial_statuses).
521+
maximize = (
522+
isinstance(problem.optimization_config, MultiObjectiveOptimizationConfig)
523+
or problem.optimization_config.objective.is_scalarized_objective
524+
or not problem.optimization_config.objective.minimize
518525
)
526+
all_trial_indices = sorted(actual_params_oracle_dummy_experiment.trials.keys())
527+
last_best = -float("inf") if maximize else float("inf")
528+
oracle_trace_list: list[float] = []
529+
for idx in all_trial_indices:
530+
if idx in oracle_trace_dict:
531+
last_best = oracle_trace_dict[idx]
532+
oracle_trace_list.append(last_best)
533+
oracle_trace = np.array(oracle_trace_list)
534+
519535
is_feasible_trace = np.array(
520536
get_is_feasible_trace(
521537
experiment=actual_params_oracle_dummy_experiment,

ax/plot/pareto_frontier.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -61,8 +61,8 @@ def scatter_plot_with_hypervolume_trace_plotly(experiment: Experiment) -> go.Fig
6161

6262
df = pd.DataFrame(
6363
{
64-
"hypervolume": hypervolume_trace,
65-
"trial_index": [*range(len(hypervolume_trace))],
64+
"trial_index": list(hypervolume_trace.keys()),
65+
"hypervolume": list(hypervolume_trace.values()),
6666
}
6767
)
6868

ax/service/tests/test_best_point.py

Lines changed: 44 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,7 @@ def test_get_trace(self) -> None:
5757
exp = get_experiment_with_observations(
5858
observations=[[11], [10], [9], [15], [5]], minimize=True
5959
)
60-
self.assertEqual(get_trace(exp), [11, 10, 9, 9, 5])
60+
self.assertEqual(get_trace(exp), {0: 11, 1: 10, 2: 9, 3: 9, 4: 5})
6161

6262
# Same experiment with maximize via new optimization config.
6363
opt_conf = none_throws(exp.optimization_config).clone()
@@ -67,7 +67,7 @@ def test_get_trace(self) -> None:
6767
opt_conf.objective.metric_names[0]: opt_conf.objective.metric_names[0]
6868
},
6969
)
70-
self.assertEqual(get_trace(exp, opt_conf), [11, 11, 11, 15, 15])
70+
self.assertEqual(get_trace(exp, opt_conf), {0: 11, 1: 11, 2: 11, 3: 15, 4: 15})
7171

7272
with self.subTest("Single objective with constraints"):
7373
# The second metric is the constraint and needs to be >= 0
@@ -76,48 +76,52 @@ def test_get_trace(self) -> None:
7676
minimize=False,
7777
constrained=True,
7878
)
79-
self.assertEqual(get_trace(exp), [float("-inf"), 10, 10, 10, 11])
79+
self.assertEqual(
80+
get_trace(exp),
81+
{0: float("-inf"), 1: 10, 2: 10, 3: 10, 4: 11},
82+
)
8083

8184
exp = get_experiment_with_observations(
8285
observations=[[11, -1], [10, 1], [9, 1], [15, -1], [11, 1]],
8386
minimize=True,
8487
constrained=True,
8588
)
86-
self.assertEqual(get_trace(exp), [float("inf"), 10, 9, 9, 9])
89+
self.assertEqual(get_trace(exp), {0: float("inf"), 1: 10, 2: 9, 3: 9, 4: 9})
8790

8891
# Scalarized.
8992
exp = get_experiment_with_observations(
9093
observations=[[1, 1], [2, 2], [3, 3]],
9194
scalarized=True,
9295
)
93-
self.assertEqual(get_trace(exp), [2, 4, 6])
96+
self.assertEqual(get_trace(exp), {0: 2, 1: 4, 2: 6})
9497

9598
# Multi objective.
9699
exp = get_experiment_with_observations(
97100
observations=[[1, 1], [-1, 100], [1, 2], [3, 3], [2, 4], [2, 1]],
98101
)
99-
self.assertEqual(get_trace(exp), [1, 1, 2, 9, 11, 11])
102+
self.assertEqual(get_trace(exp), {0: 1, 1: 1, 2: 2, 3: 9, 4: 11, 5: 11})
100103

101104
# W/o ObjectiveThresholds (inferring ObjectiveThresholds from scaled nadir)
102105
assert_is_instance(
103106
exp.optimization_config, MultiObjectiveOptimizationConfig
104107
).objective_thresholds = []
105108
trace = get_trace(exp)
109+
trace_values = list(trace.values())
106110
# With inferred thresholds via scaled nadir, check trace properties:
107111
# - All values should be non-negative
108-
self.assertTrue(all(v >= 0.0 for v in trace))
112+
self.assertTrue(all(v >= 0.0 for v in trace_values))
109113
# - Trace should be non-decreasing (cumulative best)
110-
for i in range(1, len(trace)):
111-
self.assertGreaterEqual(trace[i], trace[i - 1])
114+
for i in range(1, len(trace_values)):
115+
self.assertGreaterEqual(trace_values[i], trace_values[i - 1])
112116
# - Final value should be positive (non-trivial HV)
113-
self.assertGreater(trace[-1], 0.0)
117+
self.assertGreater(trace_values[-1], 0.0)
114118

115119
# Multi-objective w/ constraints.
116120
exp = get_experiment_with_observations(
117121
observations=[[-1, 1, 1], [1, 2, 1], [3, 3, -1], [2, 4, 1], [2, 1, 1]],
118122
constrained=True,
119123
)
120-
self.assertEqual(get_trace(exp), [0, 2, 2, 8, 8])
124+
self.assertEqual(get_trace(exp), {0: 0, 1: 2, 2: 2, 3: 8, 4: 8})
121125

122126
# W/ relative constraints & status quo.
123127
exp.status_quo = Arm(parameters={"x": 0.5, "y": 0.5}, name="status_quo")
@@ -149,17 +153,17 @@ def test_get_trace(self) -> None:
149153
]
150154
status_quo_data = Data(df=pd.DataFrame.from_records(df_dict))
151155
exp.attach_data(data=status_quo_data)
152-
self.assertEqual(get_trace(exp), [0, 2, 2, 8, 8])
156+
self.assertEqual(get_trace(exp), {0: 0, 1: 2, 2: 2, 3: 8, 4: 8})
153157

154158
# W/ first objective being minimized.
155159
exp = get_experiment_with_observations(
156160
observations=[[1, 1], [-1, 2], [3, 3], [-2, 4], [2, 1]], minimize=True
157161
)
158-
self.assertEqual(get_trace(exp), [0, 2, 2, 8, 8])
162+
self.assertEqual(get_trace(exp), {0: 0, 1: 2, 2: 2, 3: 8, 4: 8})
159163

160164
# W/ empty data.
161165
exp = get_experiment_with_trial()
162-
self.assertEqual(get_trace(exp), [])
166+
self.assertEqual(get_trace(exp), {})
163167

164168
# test batch trial
165169
exp = get_experiment_with_batch_trial(with_status_quo=False)
@@ -191,7 +195,7 @@ def test_get_trace(self) -> None:
191195
]
192196
)
193197
exp.attach_data(Data(df=pd.DataFrame.from_records(df_dict)))
194-
self.assertEqual(get_trace(exp), [2.0])
198+
self.assertEqual(get_trace(exp), {0: 2.0})
195199
# test that there is performance metric in the trace for each
196200
# completed/early-stopped trial
197201
trial1 = assert_is_instance(trial, BatchTrial).clone_to(include_sq=False)
@@ -214,7 +218,7 @@ def test_get_trace(self) -> None:
214218
]
215219
)
216220
exp.attach_data(Data(df=pd.DataFrame.from_records(df_dict2)))
217-
self.assertEqual(get_trace(exp), [2.0, 2.0, 20.0])
221+
self.assertEqual(get_trace(exp), {0: 2.0, 2: 20.0})
218222

219223
def test_get_trace_with_non_completed_trials(self) -> None:
220224
with self.subTest("minimize with abandoned trial"):
@@ -224,12 +228,11 @@ def test_get_trace_with_non_completed_trials(self) -> None:
224228
# Mark trial 2 (value=9) as abandoned
225229
exp.trials[2].mark_abandoned(unsafe=True)
226230

227-
# Abandoned trial carries forward the last best value
231+
# Abandoned trial is excluded from trace
228232
trace = get_trace(exp)
229-
self.assertEqual(len(trace), 5)
230-
# Trial 0: 11, Trial 1: 10, Trial 2 (abandoned): carry forward 10
233+
# Trial 0: 11, Trial 1: 10, Trial 2 (abandoned): excluded
231234
# Trial 3: 10 (15 > 10), Trial 4: 5
232-
self.assertEqual(trace, [11, 10, 10, 10, 5])
235+
self.assertEqual(trace, {0: 11, 1: 10, 3: 10, 4: 5})
233236

234237
with self.subTest("maximize with abandoned trial"):
235238
exp = get_experiment_with_observations(
@@ -238,12 +241,11 @@ def test_get_trace_with_non_completed_trials(self) -> None:
238241
# Mark trial 1 (value=3) as abandoned
239242
exp.trials[1].mark_abandoned(unsafe=True)
240243

241-
# Abandoned trial carries forward the last best value
244+
# Abandoned trial is excluded from trace
242245
trace = get_trace(exp)
243-
self.assertEqual(len(trace), 5)
244-
# Trial 0: 1, Trial 1 (abandoned): carry forward 1,
246+
# Trial 0: 1, Trial 1 (abandoned): excluded,
245247
# Trial 2: 2, Trial 3: 5, Trial 4: 5
246-
self.assertEqual(trace, [1, 1, 2, 5, 5])
248+
self.assertEqual(trace, {0: 1, 2: 2, 3: 5, 4: 5})
247249

248250
with self.subTest("minimize with failed trial"):
249251
exp = get_experiment_with_observations(
@@ -252,12 +254,11 @@ def test_get_trace_with_non_completed_trials(self) -> None:
252254
# Mark trial 2 (value=9) as failed
253255
exp.trials[2].mark_failed(unsafe=True)
254256

255-
# Failed trial carries forward the last best value
257+
# Failed trial is excluded from trace
256258
trace = get_trace(exp)
257-
self.assertEqual(len(trace), 5)
258-
# Trial 0: 11, Trial 1: 10, Trial 2 (failed): carry forward 10
259+
# Trial 0: 11, Trial 1: 10, Trial 2 (failed): excluded
259260
# Trial 3: 10 (15 > 10), Trial 4: 5
260-
self.assertEqual(trace, [11, 10, 10, 10, 5])
261+
self.assertEqual(trace, {0: 11, 1: 10, 3: 10, 4: 5})
261262

262263
def test_get_trace_with_include_status_quo(self) -> None:
263264
with self.subTest("Multi-objective: status quo dominates in some trials"):
@@ -347,9 +348,11 @@ def test_get_trace_with_include_status_quo(self) -> None:
347348
# The last value MUST differ because status quo dominates
348349
# Without status quo, only poor arms contribute (low hypervolume)
349350
# With status quo, excellent values contribute (high hypervolume)
351+
last_without = list(trace_without_sq.values())[-1]
352+
last_with = list(trace_with_sq.values())[-1]
350353
self.assertGreater(
351-
trace_with_sq[-1],
352-
trace_without_sq[-1],
354+
last_with,
355+
last_without,
353356
f"Status quo dominates in trial 3, so trace with SQ should be higher. "
354357
f"Without SQ: {trace_without_sq}, With SQ: {trace_with_sq}",
355358
)
@@ -418,9 +421,11 @@ def test_get_trace_with_include_status_quo(self) -> None:
418421
# The last value MUST differ because status quo is best
419422
# Without status quo: best in trial 3 is 15.0, cumulative min is 9
420423
# With status quo: best in trial 3 is 5.0, cumulative min is 5
424+
last_without = list(trace_without_sq.values())[-1]
425+
last_with = list(trace_with_sq.values())[-1]
421426
self.assertLess(
422-
trace_with_sq[-1],
423-
trace_without_sq[-1],
427+
last_with,
428+
last_without,
424429
f"Status quo is best in trial 3, so trace with SQ should be "
425430
f"lower (minimize). Without SQ: {trace_without_sq}, "
426431
f"With SQ: {trace_with_sq}",
@@ -502,19 +507,20 @@ def _make_pref_opt_config(self, profile_name: str) -> PreferenceOptimizationConf
502507
preference_profile_name=profile_name,
503508
)
504509

505-
def _assert_valid_trace(self, trace: list[float], expected_len: int) -> None:
510+
def _assert_valid_trace(self, trace: dict[int, float], expected_len: int) -> None:
506511
"""Assert trace has expected length, contains floats, is non-decreasing and has
507512
more than one unique value."""
508513
self.assertEqual(len(trace), expected_len)
509-
for value in trace:
514+
trace_values = list(trace.values())
515+
for value in trace_values:
510516
self.assertIsInstance(value, float)
511-
for i in range(1, len(trace)):
517+
for i in range(1, len(trace_values)):
512518
self.assertGreaterEqual(
513-
trace[i],
514-
trace[i - 1],
519+
trace_values[i],
520+
trace_values[i - 1],
515521
msg=f"Trace not monotonically increasing at index {i}: {trace}",
516522
)
517-
unique_values = set(trace)
523+
unique_values = set(trace_values)
518524
self.assertGreater(
519525
len(unique_values),
520526
1,

0 commit comments

Comments
 (0)