
Commit fb5afd8

ltiao authored and facebook-github-bot committed
Clean up align_partial_results, eliminate double alignment, and remove Data round-trip
Summary:

Profiling-driven cleanup and performance improvements to the early stopping pipeline.

**1. Remove redundant validations from `align_partial_results` (utils.py)**

`align_partial_results` contained three validation/logging blocks that are redundant with upstream checks already performed by `_lookup_and_validate` in `base.py`:

- **Missing metrics check + `raise ValueError` (lines 166-168):** `_lookup_and_validate` already verifies each metric signature exists in the data (base.py lines 234-240) and returns `None` before `align_partial_results` is ever called.
- **Per-metric logging loop (lines 172-180):** The "no data" branch is unreachable because the upstream checks guarantee data exists for each metric. The debug logging about MAP_KEY ranges is misplaced for a pure data-alignment function.
- **Trial-to-arm uniqueness check (lines 186-198):** Redundant with `is_eligible_any` (base.py lines 344-352) which already rejects `BatchTrial` -- the only way a trial gets multiple arms.

The arm-to-trial uniqueness check was moved (not removed) to `_lookup_and_validate`, where it properly guards all downstream consumers, not just alignment.

The `isin` filter (`df = df[df['metric_signature'].isin(metrics)]`) was kept -- it is part of `align_partial_results`'s own contract since the function accepts a `metrics` argument and callers depend on it filtering to those metrics.

After cleanup, `align_partial_results` is a focused alignment function: filter to requested metrics -> drop `arm_name` -> drop duplicates -> sort -> pivot -> interpolate.

**2. Eliminate double data alignment in `PercentileEarlyStoppingStrategy` (percentile.py)**

When `check_safe=True`, the base class `should_stop_trials_early` calls `_is_harmful()` which calls `_prepare_aligned_frames()`, then `_should_stop_trials_early()` calls `_prepare_aligned_frames()` again -- identical data lookup + alignment running twice.

Fix: override `should_stop_trials_early` to call `_default_objective_and_direction()` and `_prepare_aligned_frames()` once, passing results as explicit keyword arguments (`metric_signature`, `minimize`, `aligned_frames`) to both `_is_harmful` and `_should_stop_trials_early`. Each defaults to `None` and falls back to computing from scratch when not provided, preserving backward compatibility for subclasses.

**3. Eliminate the `Data` round-trip in `_lookup_and_validate` (base.py, threshold.py)**

`_lookup_and_validate` (formerly `_lookup_and_validate_data`) used to wrap its filtered DataFrame back into `Data(df=filtered_df)` at the end, purely to satisfy its `Data | None` return type. The sole consumer (`_prepare_aligned_frames`) immediately called `.full_df` to get the DataFrame back. This triggered a pointless df-to-DataRow-to-df round-trip on every call:

```
DataFrame -> itertuples -> list[DataRow] -> from_records -> regex sort -> cast -> DataFrame
```

Profiling shows this round-trip is **100% overhead** at every scale:

| Scale            | Rows    | Round-trip cost | Replacement (df.copy) |
|------------------|---------|-----------------|-----------------------|
| tiny (5x10)      | 50      | 4ms             | <0.1ms                |
| typical (20x100) | 2,000   | 15ms            | <0.1ms                |
| large (50x200)   | 10,000  | 60ms            | <0.1ms                |
| xlarge (100x200) | 20,000  | 160ms           | <0.1ms                |
| huge (200x500)   | 100,000 | 733ms           | <0.1ms                |

The cost comes from `Data.__init__` iterating every row via `itertuples()` to build `list[DataRow]` (~35ms at 10k rows), then `Data.full_df` reconstructing the DataFrame via `from_records()` (~12ms) and running regex-based arm name parsing in `sort_by_trial_index_and_arm_name()` (~19ms) -- none of which is needed since we already had the DataFrame.

Fix: change `_lookup_and_validate` to return `pd.DataFrame | None` directly. Update `ModelBasedEarlyStoppingStrategy` and `ThresholdEarlyStoppingStrategy` accordingly.

End-to-end profiling of `should_stop_trials_early` at the 50x200 scale (10k rows) shows a ~19% speedup (323ms -> ~263ms), with the benefit growing at larger scales (up to ~41% at 100k rows).

**4. Naming cleanup**

Renamed methods and variables to disambiguate between `Data` objects and `pd.DataFrame`:

- `_lookup_and_validate_data` -> `_lookup_and_validate` (returns `pd.DataFrame | None`)
- `_prepare_aligned_data` -> `_prepare_aligned_frames` (returns a tuple of DataFrames)
- Variables holding `_lookup_and_validate` results: `data`/`map_data`/`data_lookup` -> `df`
- Parameter `aligned_data` -> `aligned_frames`

**5. Tests**

- Removed tests for moved/removed `align_partial_results` validations.
- Added test for the arm-to-trial uniqueness check in `_lookup_and_validate`.
- Added profiling notebook at `ax/early_stopping/profiling.ipynb`.

Differential Revision: D98544835
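The six alignment steps named in section 1 (filter, drop `arm_name`, drop duplicates, sort, pivot, interpolate) can be sketched on a toy frame. This is a minimal illustration and not the code in `utils.py`: the function name `align_sketch`, the `"step"` stand-in for `MAP_KEY`, the duplicate key columns, and the choice of `method="index"` interpolation are all assumptions made for the example.

```python
import pandas as pd

MAP_KEY = "step"  # assumption: illustrative name for the progression column


def align_sketch(df: pd.DataFrame, metrics: list[str]) -> pd.DataFrame:
    """Illustrative stand-in for the six alignment steps; not the Ax code."""
    # 1. Filter to the requested metrics.
    df = df[df["metric_signature"].isin(metrics)]
    # 2. Drop arm_name (each trial is assumed to carry a single arm).
    df = df.drop(columns=["arm_name"])
    # 3. Drop repeated observations of the same trial/metric/progression.
    keys = ["trial_index", "metric_signature", MAP_KEY]
    df = df.drop_duplicates(subset=keys)
    # 4. Sort so rows are ordered by trial, metric, then progression.
    df = df.sort_values(keys)
    # 5. Pivot to wide format: one column per (value, metric, trial).
    wide = df.pivot(
        index=MAP_KEY,
        columns=["metric_signature", "trial_index"],
        values=["mean", "sem"],
    )
    # 6. Interpolate only gaps that lie between observed progressions.
    return wide.interpolate(method="index", limit_area="inside")


long_df = pd.DataFrame(
    {
        "trial_index": [0, 0, 0, 1, 1, 0],
        "arm_name": ["0_0", "0_0", "0_0", "1_0", "1_0", "0_0"],
        "metric_signature": ["loss"] * 5 + ["other"],
        MAP_KEY: [0, 1, 2, 0, 2, 0],
        "mean": [1.0, 2.0, 3.0, 10.0, 30.0, 99.0],
        "sem": [0.1, 0.1, 0.1, 0.5, 0.5, 0.0],
    }
)
wide_df = align_sketch(long_df, metrics=["loss"])
loss_means = wide_df["mean"]["loss"]  # columns are trial indices
```

Trial 1 has no observation at step 1 in the toy data, so that cell is filled by interpolation between its step 0 and step 2 values, while the `other` metric is filtered out before the pivot.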
1 parent f5976b8 commit fb5afd8

5 files changed

Lines changed: 206 additions & 172 deletions


ax/early_stopping/strategies/base.py

Lines changed: 28 additions & 23 deletions
@@ -15,7 +15,7 @@
 import pandas as pd
 from ax.adapter.data_utils import _maybe_normalize_map_key
 from ax.core.batch_trial import BatchTrial
-from ax.core.data import Data, MAP_KEY
+from ax.core.data import MAP_KEY
 from ax.core.experiment import Experiment
 from ax.core.trial_status import TrialStatus
 from ax.early_stopping.utils import (
@@ -217,13 +217,10 @@ def estimate_early_stopping_savings(self, experiment: Experiment) -> float:
 
         return estimate_early_stopping_savings(experiment=experiment)
 
-    def _lookup_and_validate_data(
+    def _lookup_and_validate(
         self, experiment: Experiment, metric_signatures: list[str]
-    ) -> Data | None:
-        """Looks up and validates the `Data` used for early stopping that
-        is associated with `metric_signatures`. This function also handles normalizing
-        progressions.
-        """
+    ) -> pd.DataFrame | None:
+        """Look up and validate experiment data for early stopping."""
         data = experiment.lookup_data()
         if data.df.empty:
             logger.info(
@@ -250,6 +247,17 @@ def _lookup_and_validate_data(
         full_df = data.full_df
         full_df = full_df[full_df["metric_signature"].isin(metric_signatures)]
 
+        # Check that no arm name appears across multiple trials.
+        # This can happen with duplicate arm parameterizations that reuse arm
+        # names across trials, which would corrupt the alignment step.
+        arm_trial_counts = full_df.groupby("arm_name")["trial_index"].nunique()
+        bad_arms = arm_trial_counts[arm_trial_counts > 1]
+        if len(bad_arms) > 0:
+            raise UnsupportedError(
+                f"Arm(s) {bad_arms.index.tolist()} appear across multiple "
+                f"trial indices. Each arm name must map to exactly one trial."
+            )
+
         # Drop rows with NaN values in MAP_KEY column to prevent issues in
         # align_partial_results which uses MAP_KEY as the pivot index
         nan_mask = full_df[MAP_KEY].isna()
@@ -264,7 +272,7 @@ def _lookup_and_validate_data(
 
         if self.normalize_progressions:
             full_df = _maybe_normalize_map_key(df=full_df)
-        return Data(df=full_df)
+        return full_df
 
     @staticmethod
     def _log_and_return_no_data(
@@ -547,7 +555,7 @@ def _all_objectives_and_directions(self, experiment: Experiment) -> dict[str, bo
 
         return directions
 
-    def _prepare_aligned_data(
+    def _prepare_aligned_frames(
         self, experiment: Experiment, metric_signatures: list[str]
     ) -> tuple[pd.DataFrame, pd.DataFrame] | None:
         """Get raw experiment data and align it for early stopping evaluation.
@@ -564,15 +572,15 @@ def _prepare_aligned_data(
               with first level ["mean", "sem"] and second level metric signatures
             Returns None if data cannot be retrieved or aligned.
         """
-        data = self._lookup_and_validate_data(
+        long_df = self._lookup_and_validate(
             experiment=experiment, metric_signatures=metric_signatures
         )
-        if data is None:
+        if long_df is None:
             return None
 
         try:
             multilevel_wide_df = align_partial_results(
-                df=(long_df := data.full_df),
+                df=long_df,
                 metrics=metric_signatures,
             )
         except Exception as e:
@@ -651,18 +659,15 @@ def __init__(
         )
         self.min_progression_modeling = min_progression_modeling
 
-    def _lookup_and_validate_data(
+    def _lookup_and_validate(
         self, experiment: Experiment, metric_signatures: list[str]
-    ) -> Data | None:
-        """Looks up and validates the `Data` used for early stopping that
-        is associated with `metric_signatures`. This function also handles normalizing
-        progressions.
+    ) -> pd.DataFrame | None:
+        """Look up and validate experiment data for early stopping, applying
+        ``min_progression_modeling`` filter if configured.
         """
-        map_data = super()._lookup_and_validate_data(
+        df = super()._lookup_and_validate(
             experiment=experiment, metric_signatures=metric_signatures
         )
-        if map_data is not None and self.min_progression_modeling is not None:
-            full_df = map_data.full_df
-            full_df = full_df[full_df[MAP_KEY] >= self.min_progression_modeling]
-            map_data = Data(df=full_df)
-        return map_data
+        if df is not None and self.min_progression_modeling is not None:
+            df = df[df[MAP_KEY] >= self.min_progression_modeling]
+        return df
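
The arm-name check added to `_lookup_and_validate` in this file can be tried in isolation on a small frame. The sketch below copies the `groupby`/`nunique` logic from the diff; the function name `check_arm_names_unique_to_trial` is hypothetical, and `ValueError` stands in for Ax's `UnsupportedError` so the example has no Ax dependency.

```python
import pandas as pd


def check_arm_names_unique_to_trial(df: pd.DataFrame) -> None:
    """Raise if any arm name is observed under more than one trial index."""
    arm_trial_counts = df.groupby("arm_name")["trial_index"].nunique()
    bad_arms = arm_trial_counts[arm_trial_counts > 1]
    if len(bad_arms) > 0:
        # ValueError is a stand-in; the diff raises Ax's UnsupportedError.
        raise ValueError(
            f"Arm(s) {bad_arms.index.tolist()} appear across multiple "
            "trial indices. Each arm name must map to exactly one trial."
        )


ok_df = pd.DataFrame(
    {"arm_name": ["0_0", "0_0", "1_0"], "trial_index": [0, 0, 1]}
)
bad_df = pd.DataFrame(
    {"arm_name": ["0_0", "0_0", "1_0"], "trial_index": [0, 2, 1]}
)

check_arm_names_unique_to_trial(ok_df)  # passes silently
try:
    check_arm_names_unique_to_trial(bad_df)
    raised = False
    message = ""
except ValueError as err:
    raised = True
    message = str(err)
```

Repeated rows for the same arm within one trial are fine (that is what curve data looks like); only an arm name that spans two or more distinct trial indices trips the check.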

ax/early_stopping/strategies/percentile.py

Lines changed: 92 additions & 18 deletions
@@ -124,10 +124,63 @@ def __init__(
                 "with multiple metrics."
             )
 
+    def should_stop_trials_early(
+        self,
+        trial_indices: set[int],
+        experiment: Experiment,
+        current_node: GenerationNode | None = None,
+    ) -> dict[int, str | None]:
+        """Decide whether trials should be stopped before evaluation is fully concluded.
+
+        Overrides the base class to compute aligned data once and reuse it for
+        both the safety check (``_is_harmful``) and the stopping decision
+        (``_should_stop_trials_early``), avoiding redundant data lookups and
+        alignment when ``check_safe=True``.
+
+        Args:
+            trial_indices: Indices of candidate trials to stop early.
+            experiment: Experiment that contains the trials and other contextual data.
+            current_node: The current ``GenerationNode`` on the ``GenerationStrategy``
+                used to generate trials for the ``Experiment``.
+
+        Returns:
+            A dictionary mapping trial indices that should be early stopped to
+            (optional) messages with the associated reason.
+        """
+        metric_signature, minimize = self._default_objective_and_direction(
+            experiment=experiment
+        )
+        aligned_frames = self._prepare_aligned_frames(
+            experiment=experiment, metric_signatures=[metric_signature]
+        )
+        if aligned_frames is None:
+            return {}
+
+        if self.check_safe and self._is_harmful(
+            trial_indices=trial_indices,
+            experiment=experiment,
+            metric_signature=metric_signature,
+            minimize=minimize,
+            aligned_frames=aligned_frames,
+        ):
+            return {}
+
+        return self._should_stop_trials_early(
+            trial_indices=trial_indices,
+            experiment=experiment,
+            current_node=current_node,
+            metric_signature=metric_signature,
+            minimize=minimize,
+            aligned_frames=aligned_frames,
+        )
+
     def _is_harmful(
         self,
         trial_indices: set[int],
         experiment: Experiment,
+        metric_signature: str | None = None,
+        minimize: bool | None = None,
+        aligned_frames: tuple[pd.DataFrame, pd.DataFrame] | None = None,
     ) -> bool:
         """Check if the early stopping strategy would stop the globally best trial.
 
@@ -139,21 +192,30 @@ def _is_harmful(
         Args:
             trial_indices: Set of trial indices being evaluated (ignored).
             experiment: Experiment that contains the trials and other contextual data.
+            metric_signature: The metric signature to evaluate. If not provided,
+                it is inferred from the experiment's objective.
+            minimize: Whether the metric is being minimized. If not provided,
+                it is inferred from the experiment's objective.
+            aligned_frames: Pre-computed ``(long_df, multilevel_wide_df)`` tuple
+                from ``_prepare_aligned_frames``. If not provided, it is computed
+                from scratch.
 
         Returns:
             True if the strategy would have stopped the globally best trial,
             False otherwise.
         """
-        metric_signature, minimize = self._default_objective_and_direction(
-            experiment=experiment
-        )
-        maybe_aligned_dataframes = self._prepare_aligned_data(
-            experiment=experiment, metric_signatures=[metric_signature]
-        )
-        if maybe_aligned_dataframes is None:
-            return False
+        if metric_signature is None or minimize is None:
+            metric_signature, minimize = self._default_objective_and_direction(
+                experiment=experiment
+            )
+        if aligned_frames is None:
+            aligned_frames = self._prepare_aligned_frames(
+                experiment=experiment, metric_signatures=[metric_signature]
+            )
+        if aligned_frames is None:
+            return False
 
-        long_df, multilevel_wide_df = maybe_aligned_dataframes
+        long_df, multilevel_wide_df = aligned_frames
         wide_df = multilevel_wide_df["mean"][metric_signature]
 
         # Get completed trials
@@ -180,6 +242,9 @@ def _should_stop_trials_early(
         trial_indices: set[int],
         experiment: Experiment,
         current_node: GenerationNode | None = None,
+        metric_signature: str | None = None,
+        minimize: bool | None = None,
+        aligned_frames: tuple[pd.DataFrame, pd.DataFrame] | None = None,
     ) -> dict[int, str | None]:
         """Stop a trial if its performance is in the bottom `percentile_threshold`
         of the trials at the same step.
@@ -191,22 +256,31 @@ def _should_stop_trials_early(
                 used to generate trials for the ``Experiment``. Early stopping
                 strategies may utilize components of the current node when making
                 stopping decisions.
+            metric_signature: The metric signature to evaluate. If not provided,
+                it is inferred from the experiment's objective.
+            minimize: Whether the metric is being minimized. If not provided,
+                it is inferred from the experiment's objective.
+            aligned_frames: Pre-computed ``(long_df, multilevel_wide_df)`` tuple
+                from ``_prepare_aligned_frames``. If not provided, it is computed
+                from scratch.
 
         Returns:
             A dictionary mapping trial indices that should be early stopped to
             (optional) messages with the associated reason. An empty dictionary
             means no suggested updates to any trial's status.
         """
-        metric_signature, minimize = self._default_objective_and_direction(
-            experiment=experiment
-        )
-        maybe_aligned_dataframes = self._prepare_aligned_data(
-            experiment=experiment, metric_signatures=[metric_signature]
-        )
-        if maybe_aligned_dataframes is None:
-            return {}
+        if metric_signature is None or minimize is None:
+            metric_signature, minimize = self._default_objective_and_direction(
+                experiment=experiment
+            )
+        if aligned_frames is None:
+            aligned_frames = self._prepare_aligned_frames(
+                experiment=experiment, metric_signatures=[metric_signature]
+            )
+        if aligned_frames is None:
+            return {}
 
-        long_df, multilevel_wide_df = maybe_aligned_dataframes
+        long_df, multilevel_wide_df = aligned_frames
         wide_df = multilevel_wide_df["mean"][metric_signature]
 
         # default checks on `min_progression` and `min_curves`; if not met, don't do
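
The compute-once pattern in this file (prepare the aligned frames in the public method, then hand them to both helpers through optional keyword arguments that fall back to recomputing) can be shown without Ax. The class and method names below are hypothetical, and strings stand in for the two DataFrames; the sketch only counts how often the expensive step runs.

```python
from __future__ import annotations


class StrategySketch:
    """Hypothetical stand-in for the compute-once, pass-down pattern."""

    def __init__(self, check_safe: bool = True) -> None:
        self.check_safe = check_safe
        self.prepare_calls = 0  # how many times the expensive step ran

    def _prepare(self) -> tuple[str, str] | None:
        self.prepare_calls += 1
        return ("long_df", "wide_df")  # placeholders for two DataFrames

    def _is_harmful(self, aligned: tuple[str, str] | None = None) -> bool:
        if aligned is None:  # fallback for callers that pass nothing
            aligned = self._prepare()
        if aligned is None:
            return False
        return False  # the toy safety check never vetoes

    def _decide(
        self, aligned: tuple[str, str] | None = None
    ) -> dict[int, str | None]:
        if aligned is None:  # same fallback as above
            aligned = self._prepare()
        if aligned is None:
            return {}
        return {0: "bottom percentile"}

    def should_stop(self) -> dict[int, str | None]:
        aligned = self._prepare()  # computed exactly once here
        if aligned is None:
            return {}
        if self.check_safe and self._is_harmful(aligned=aligned):
            return {}
        return self._decide(aligned=aligned)


strategy = StrategySketch(check_safe=True)
decision = strategy.should_stop()

legacy = StrategySketch()
legacy._is_harmful()  # old-style call sites still work, at full cost
legacy._decide()
```

Keeping the `None` defaults means a subclass or test that still calls a helper directly keeps working; it simply pays for its own lookup and alignment.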

ax/early_stopping/strategies/threshold.py

Lines changed: 2 additions & 4 deletions
@@ -112,15 +112,13 @@ def _should_stop_trials_early(
         metric_signature, minimize = self._default_objective_and_direction(
             experiment=experiment
         )
-        data = self._lookup_and_validate_data(
+        df = self._lookup_and_validate(
             experiment=experiment, metric_signatures=[metric_signature]
         )
-        if data is None:
+        if df is None:
             # don't stop any trials if we don't get data back
             return {}
 
-        df = data.full_df
-
         # default checks on `min_progression` and `min_curves`; if not met, don't do
         # early stopping at all and return {}
         if not self.is_eligible_any(