Commit c6a49ee

ltiao authored and facebook-github-bot committed
Clean up align_partial_results, eliminate double alignment, and remove Data round-trip
Summary:

Profiling-driven cleanup and performance improvements to the early stopping pipeline.

**1. Remove redundant validations from `align_partial_results` (utils.py)**

`align_partial_results` contained three validation/logging blocks that are redundant with upstream checks already performed by `_lookup_and_validate` in `base.py`:

- **Missing metrics check + `raise ValueError`:** `_lookup_and_validate` already verifies each metric signature exists in the data and returns `None` before `align_partial_results` is ever called.
- **Per-metric logging loop:** The "no data" branch is unreachable because the upstream checks guarantee data exists for each metric. Debug logging about MAP_KEY ranges is misplaced for a pure alignment function.
- **Trial-to-arm uniqueness check:** Redundant with `is_eligible_any`, which already rejects `BatchTrial` -- the only way a trial gets multiple arms.

The arm-to-trial uniqueness check was moved (not removed) to `_lookup_and_validate`, where it properly guards all downstream consumers, not just alignment.

The `isin` filter (`df = df[df['metric_signature'].isin(metrics)]`) was kept -- it is part of `align_partial_results`'s own contract since the function accepts a `metrics` argument and callers depend on it filtering to those metrics.

After cleanup, `align_partial_results` is a focused alignment function: filter to requested metrics -> drop `arm_name` -> drop duplicates -> sort -> pivot -> interpolate.

**2. Eliminate double data alignment in `PercentileEarlyStoppingStrategy` (percentile.py)**

When `check_safe=True`, the base class `should_stop_trials_early` calls `_is_harmful()`, which calls `_prepare_aligned_frames()`; then `_should_stop_trials_early()` calls `_prepare_aligned_frames()` again -- identical data lookup + alignment running twice.

Fix: override `should_stop_trials_early` to call `_default_objective_and_direction()` and `_prepare_aligned_frames()` once, passing results as required arguments (`metric_signature`, `minimize`, `aligned_frames`) to both `_is_harmful` and `_should_stop_trials_early`.

**3. Eliminate the `Data` round-trip in `_lookup_and_validate` (base.py, threshold.py)**

`_lookup_and_validate` (formerly `_lookup_and_validate_data`) used to wrap its filtered DataFrame back into `Data(df=filtered_df)` at the end, purely to satisfy its `Data | None` return type. The sole consumer (`_prepare_aligned_frames`) immediately called `.full_df` to get the DataFrame back. This triggered a pointless df-to-DataRow-to-df round-trip on every call:

```
DataFrame -> itertuples -> list[DataRow] -> from_records -> regex sort -> cast -> DataFrame
```

Profiling shows this round-trip is **100% overhead** at every scale:

| Scale            | Rows    | Round-trip cost | Replacement (df.copy) |
|------------------|---------|-----------------|-----------------------|
| tiny (5x10)      | 50      | 4ms             | <0.1ms                |
| typical (20x100) | 2,000   | 15ms            | <0.1ms                |
| large (50x200)   | 10,000  | 60ms            | <0.1ms                |
| xlarge (100x200) | 20,000  | 160ms           | <0.1ms                |
| huge (200x500)   | 100,000 | 733ms           | <0.1ms                |

The cost comes from `Data.__init__` iterating every row via `itertuples()` to build `list[DataRow]` (~35ms at 10k rows), then `Data.full_df` reconstructing the DataFrame via `from_records()` (~12ms) and running regex-based arm name parsing in `sort_by_trial_index_and_arm_name()` (~19ms) -- none of which is needed since we already had the DataFrame.

Fix: change `_lookup_and_validate` to return `pd.DataFrame | None` directly. Update `ModelBasedEarlyStoppingStrategy` and `ThresholdEarlyStoppingStrategy` accordingly.

End-to-end profiling of `should_stop_trials_early` at the 50x200 scale (10k rows) shows a ~19% speedup (323ms -> ~263ms), with the benefit growing at larger scales (up to ~41% at 100k rows).

**4. Naming cleanup**

Renamed methods and variables to disambiguate between `Data` objects and `pd.DataFrame`:

- `_lookup_and_validate_data` -> `_lookup_and_validate` (returns `pd.DataFrame | None`)
- `_prepare_aligned_data` -> `_prepare_aligned_frames` (returns a tuple of DataFrames)
- Variables holding `_lookup_and_validate` results: `data`/`map_data`/`data_lookup` -> `df`
- Parameter `aligned_data` -> `aligned_frames`

**5. Tests**

- Removed tests for moved/removed `align_partial_results` validations.
- Added arm-to-trial uniqueness check as a subtest in `test_is_eligible`.
- Added profiling notebook at `ax/early_stopping/profiling.ipynb`.

Reviewed By: saitcakmak

Differential Revision: D98544835
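For readers unfamiliar with the alignment step, the pipeline named in section 1 (filter to requested metrics, drop `arm_name`, drop duplicates, sort, pivot, interpolate) can be sketched in plain pandas. This is only an illustrative sketch, not the actual `align_partial_results` code: the function name `align_sketch`, the `"step"` column standing in for `MAP_KEY`, the duplicate key, and the index-based interpolation are all assumptions.

```python
import pandas as pd


def align_sketch(
    df: pd.DataFrame, metrics: list[str], map_key: str = "step"
) -> pd.DataFrame:
    """Hypothetical stand-in for the alignment steps; not Ax's implementation."""
    # 1. Keep only the requested metrics.
    df = df[df["metric_signature"].isin(metrics)]
    # 2. Drop arm_name (one arm per trial is assumed to be validated upstream).
    df = df.drop(columns=["arm_name"])
    # 3. Drop repeated observations of the same (trial, metric, progression).
    df = df.drop_duplicates(subset=["trial_index", "metric_signature", map_key])
    # 4. Sort so that each curve is ordered by progression.
    df = df.sort_values(["trial_index", map_key])
    # 5. Pivot to wide format: rows are progressions, columns are
    #    (value, metric_signature, trial_index).
    wide = df.pivot(
        index=map_key,
        columns=["metric_signature", "trial_index"],
        values=["mean", "sem"],
    )
    # 6. Fill gaps inside each curve by interpolating over the progression index.
    return wide.interpolate(method="index", limit_area="inside")


# Toy long-format data: trial 1 has no observation at step 1.0.
long_df = pd.DataFrame(
    {
        "trial_index": [0, 0, 0, 1, 1, 1],
        "arm_name": ["0_0", "0_0", "0_0", "1_0", "1_0", "1_0"],
        "metric_signature": ["m", "m", "m", "m", "m", "other"],
        "step": [0.0, 1.0, 2.0, 0.0, 2.0, 0.0],
        "mean": [1.0, 2.0, 3.0, 1.0, 3.0, 9.0],
        "sem": [0.1, 0.1, 0.1, 0.2, 0.2, 0.5],
    }
)
wide = align_sketch(long_df, metrics=["m"])
print(wide["mean"]["m"])
```

Indexing the result as `wide["mean"]["m"]` mirrors the `multilevel_wide_df["mean"][metric_signature]` access pattern that appears in the diff for `percentile.py`.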
1 parent c08c4a9 commit c6a49ee

5 files changed

Lines changed: 180 additions & 174 deletions

ax/early_stopping/strategies/base.py

Lines changed: 28 additions & 23 deletions
@@ -15,7 +15,7 @@
 import pandas as pd
 from ax.adapter.data_utils import _maybe_normalize_map_key
 from ax.core.batch_trial import BatchTrial
-from ax.core.data import Data, MAP_KEY
+from ax.core.data import MAP_KEY
 from ax.core.experiment import Experiment
 from ax.core.trial_status import TrialStatus
 from ax.early_stopping.utils import (
@@ -217,13 +217,10 @@ def estimate_early_stopping_savings(self, experiment: Experiment) -> float:
 
         return estimate_early_stopping_savings(experiment=experiment)
 
-    def _lookup_and_validate_data(
+    def _lookup_and_validate(
         self, experiment: Experiment, metric_signatures: list[str]
-    ) -> Data | None:
-        """Looks up and validates the `Data` used for early stopping that
-        is associated with `metric_signatures`. This function also handles normalizing
-        progressions.
-        """
+    ) -> pd.DataFrame | None:
+        """Look up and validate experiment data for early stopping."""
         data = experiment.lookup_data()
         if data.df.empty:
             logger.info(
@@ -250,6 +247,17 @@ def _lookup_and_validate_data(
         full_df = data.full_df
         full_df = full_df[full_df["metric_signature"].isin(metric_signatures)]
 
+        # Check that no arm name appears across multiple trials.
+        # This can happen with duplicate arm parameterizations that reuse arm
+        # names across trials, which would corrupt the alignment step.
+        arm_trial_counts = full_df.groupby("arm_name")["trial_index"].nunique()
+        bad_arms = arm_trial_counts[arm_trial_counts > 1]
+        if len(bad_arms) > 0:
+            raise UnsupportedError(
+                f"Arm(s) {bad_arms.index.tolist()} appear across multiple "
+                f"trial indices. Each arm name must map to exactly one trial."
+            )
+
         # Drop rows with NaN values in MAP_KEY column to prevent issues in
         # align_partial_results which uses MAP_KEY as the pivot index
         nan_mask = full_df[MAP_KEY].isna()
@@ -264,7 +272,7 @@ def _lookup_and_validate_data(
 
         if self.normalize_progressions:
             full_df = _maybe_normalize_map_key(df=full_df)
-        return Data(df=full_df)
+        return full_df
 
     @staticmethod
     def _log_and_return_no_data(
@@ -547,7 +555,7 @@ def _all_objectives_and_directions(self, experiment: Experiment) -> dict[str, bo
 
         return directions
 
-    def _prepare_aligned_data(
+    def _prepare_aligned_frames(
         self, experiment: Experiment, metric_signatures: list[str]
     ) -> tuple[pd.DataFrame, pd.DataFrame] | None:
         """Get raw experiment data and align it for early stopping evaluation.
@@ -564,15 +572,15 @@ def _prepare_aligned_data(
             with first level ["mean", "sem"] and second level metric signatures
         Returns None if data cannot be retrieved or aligned.
         """
-        data = self._lookup_and_validate_data(
+        long_df = self._lookup_and_validate(
             experiment=experiment, metric_signatures=metric_signatures
         )
-        if data is None:
+        if long_df is None:
             return None
 
         try:
             multilevel_wide_df = align_partial_results(
-                df=(long_df := data.full_df),
+                df=long_df,
                 metrics=metric_signatures,
             )
         except Exception as e:
@@ -651,18 +659,15 @@ def __init__(
             )
         self.min_progression_modeling = min_progression_modeling
 
-    def _lookup_and_validate_data(
+    def _lookup_and_validate(
         self, experiment: Experiment, metric_signatures: list[str]
-    ) -> Data | None:
-        """Looks up and validates the `Data` used for early stopping that
-        is associated with `metric_signatures`. This function also handles normalizing
-        progressions.
+    ) -> pd.DataFrame | None:
+        """Look up and validate experiment data for early stopping, applying
+        ``min_progression_modeling`` filter if configured.
         """
-        map_data = super()._lookup_and_validate_data(
+        df = super()._lookup_and_validate(
             experiment=experiment, metric_signatures=metric_signatures
         )
-        if map_data is not None and self.min_progression_modeling is not None:
-            full_df = map_data.full_df
-            full_df = full_df[full_df[MAP_KEY] >= self.min_progression_modeling]
-            map_data = Data(df=full_df)
-        return map_data
+        if df is not None and self.min_progression_modeling is not None:
+            df = df[df[MAP_KEY] >= self.min_progression_modeling]
+        return df
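
The arm-to-trial uniqueness check added to `_lookup_and_validate` can be tried in isolation on a toy frame. The sketch below reuses the same `groupby(...).nunique()` idiom from the diff; the helper name `find_reused_arms` and the sample data are hypothetical, and it returns the offending arm names instead of raising `UnsupportedError`.

```python
import pandas as pd


def find_reused_arms(df: pd.DataFrame) -> list[str]:
    """Return arm names that appear under more than one trial index.

    Hypothetical helper for illustration; the commit raises an error instead.
    """
    # Count distinct trial indices per arm name.
    arm_trial_counts = df.groupby("arm_name")["trial_index"].nunique()
    # Any arm seen in more than one trial would be ambiguous during alignment.
    return arm_trial_counts[arm_trial_counts > 1].index.tolist()


# Each arm belongs to exactly one trial: nothing is flagged.
ok_df = pd.DataFrame({"arm_name": ["0_0", "0_0", "1_0"], "trial_index": [0, 0, 1]})
# Arm "0_0" is reused by trials 0 and 2: it is flagged.
bad_df = pd.DataFrame({"arm_name": ["0_0", "1_0", "0_0"], "trial_index": [0, 1, 2]})

print(find_reused_arms(ok_df))
print(find_reused_arms(bad_df))
```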

ax/early_stopping/strategies/percentile.py

Lines changed: 66 additions & 20 deletions
@@ -124,10 +124,63 @@ def __init__(
                 "with multiple metrics."
             )
 
+    def should_stop_trials_early(
+        self,
+        trial_indices: set[int],
+        experiment: Experiment,
+        current_node: GenerationNode | None = None,
+    ) -> dict[int, str | None]:
+        """Decide whether trials should be stopped before evaluation is fully concluded.
+
+        Overrides the base class to compute aligned data once and reuse it for
+        both the safety check (``_is_harmful``) and the stopping decision
+        (``_should_stop_trials_early``), avoiding redundant data lookups and
+        alignment when ``check_safe=True``.
+
+        Args:
+            trial_indices: Indices of candidate trials to stop early.
+            experiment: Experiment that contains the trials and other contextual data.
+            current_node: The current ``GenerationNode`` on the ``GenerationStrategy``
+                used to generate trials for the ``Experiment``.
+
+        Returns:
+            A dictionary mapping trial indices that should be early stopped to
+            (optional) messages with the associated reason.
+        """
+        metric_signature, minimize = self._default_objective_and_direction(
+            experiment=experiment
+        )
+        aligned_frames = self._prepare_aligned_frames(
+            experiment=experiment, metric_signatures=[metric_signature]
+        )
+        if aligned_frames is None:
+            return {}
+
+        if self.check_safe and self._is_harmful(
+            trial_indices=trial_indices,
+            experiment=experiment,
+            metric_signature=metric_signature,
+            minimize=minimize,
+            aligned_frames=aligned_frames,
+        ):
+            return {}
+
+        return self._should_stop_trials_early(
+            trial_indices=trial_indices,
+            experiment=experiment,
+            current_node=current_node,
+            metric_signature=metric_signature,
+            minimize=minimize,
+            aligned_frames=aligned_frames,
+        )
+
     def _is_harmful(
         self,
         trial_indices: set[int],
         experiment: Experiment,
+        metric_signature: str,
+        minimize: bool,
+        aligned_frames: tuple[pd.DataFrame, pd.DataFrame],
     ) -> bool:
         """Check if the early stopping strategy would stop the globally best trial.
 
@@ -139,21 +192,16 @@ def _is_harmful(
         Args:
             trial_indices: Set of trial indices being evaluated (ignored).
             experiment: Experiment that contains the trials and other contextual data.
+            metric_signature: The metric signature to evaluate.
+            minimize: Whether the metric is being minimized.
+            aligned_frames: Pre-computed ``(long_df, multilevel_wide_df)`` tuple
+                from ``_prepare_aligned_frames``.
 
         Returns:
             True if the strategy would have stopped the globally best trial,
             False otherwise.
         """
-        metric_signature, minimize = self._default_objective_and_direction(
-            experiment=experiment
-        )
-        maybe_aligned_dataframes = self._prepare_aligned_data(
-            experiment=experiment, metric_signatures=[metric_signature]
-        )
-        if maybe_aligned_dataframes is None:
-            return False
-
-        long_df, multilevel_wide_df = maybe_aligned_dataframes
+        long_df, multilevel_wide_df = aligned_frames
         wide_df = multilevel_wide_df["mean"][metric_signature]
 
         # Get completed trials
@@ -179,6 +227,9 @@ def _should_stop_trials_early(
         self,
         trial_indices: set[int],
         experiment: Experiment,
+        metric_signature: str,
+        minimize: bool,
+        aligned_frames: tuple[pd.DataFrame, pd.DataFrame],
         current_node: GenerationNode | None = None,
     ) -> dict[int, str | None]:
         """Stop a trial if its performance is in the bottom `percentile_threshold`
@@ -187,6 +238,10 @@ def _should_stop_trials_early(
         Args:
             trial_indices: Indices of candidate trials to consider for early stopping.
             experiment: Experiment that contains the trials and other contextual data.
+            metric_signature: The metric signature to evaluate.
+            minimize: Whether the metric is being minimized.
+            aligned_frames: Pre-computed ``(long_df, multilevel_wide_df)`` tuple
+                from ``_prepare_aligned_frames``.
             current_node: The current ``GenerationNode`` on the ``GenerationStrategy``
                 used to generate trials for the ``Experiment``. Early stopping
                 strategies may utilize components of the current node when making
@@ -197,16 +252,7 @@ def _should_stop_trials_early(
             (optional) messages with the associated reason. An empty dictionary
             means no suggested updates to any trial's status.
         """
-        metric_signature, minimize = self._default_objective_and_direction(
-            experiment=experiment
-        )
-        maybe_aligned_dataframes = self._prepare_aligned_data(
-            experiment=experiment, metric_signatures=[metric_signature]
-        )
-        if maybe_aligned_dataframes is None:
-            return {}
-
-        long_df, multilevel_wide_df = maybe_aligned_dataframes
+        long_df, multilevel_wide_df = aligned_frames
         wide_df = multilevel_wide_df["mean"][metric_signature]
 
         # default checks on `min_progression` and `min_curves`; if not met, don't do
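
The "prepare once, pass to both consumers" shape of the new `should_stop_trials_early` override can be shown with a toy class. This is a minimal sketch under stated assumptions: `ToyStrategy`, its methods, and the call counter are hypothetical stand-ins that only loosely mirror the names in the diff, and the safety and stopping rules are deliberately trivial.

```python
from __future__ import annotations

from dataclasses import dataclass


@dataclass
class ToyStrategy:
    """Hypothetical strategy used only to illustrate the call structure."""

    check_safe: bool = True
    prepare_calls: int = 0  # counts how often the expensive step runs

    def _prepare(self, values: dict[int, float]) -> dict[int, float] | None:
        # Stand-in for the expensive data lookup + alignment.
        self.prepare_calls += 1
        return dict(values) if values else None

    def _is_harmful(self, prepared: dict[int, float]) -> bool:
        # Toy safety check: refuse to stop anything with fewer than two curves.
        return len(prepared) < 2

    def _decide(
        self, candidates: set[int], prepared: dict[int, float], threshold: float
    ) -> dict[int, str | None]:
        # Toy stopping rule: stop candidates whose value is below the threshold.
        return {
            i: "below threshold"
            for i in sorted(candidates)
            if i in prepared and prepared[i] < threshold
        }

    def should_stop(
        self, candidates: set[int], values: dict[int, float], threshold: float
    ) -> dict[int, str | None]:
        # Prepare once, then hand the same result to both consumers.
        prepared = self._prepare(values)
        if prepared is None:
            return {}
        if self.check_safe and self._is_harmful(prepared):
            return {}
        return self._decide(candidates, prepared, threshold)


strategy = ToyStrategy()
result = strategy.should_stop({0, 1}, {0: 0.1, 1: 0.9}, threshold=0.5)
print(result, strategy.prepare_calls)
```

Making the prepared data a required argument of the helpers, as the commit does with `aligned_frames`, means a helper cannot silently fall back to recomputing it.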

ax/early_stopping/strategies/threshold.py

Lines changed: 2 additions & 4 deletions
@@ -112,15 +112,13 @@ def _should_stop_trials_early(
         metric_signature, minimize = self._default_objective_and_direction(
             experiment=experiment
         )
-        data = self._lookup_and_validate_data(
+        df = self._lookup_and_validate(
             experiment=experiment, metric_signatures=[metric_signature]
         )
-        if data is None:
+        if df is None:
             # don't stop any trials if we don't get data back
             return {}
 
-        df = data.full_df
-
         # default checks on `min_progression` and `min_curves`; if not met, don't do
         # early stopping at all and return {}
         if not self.is_eligible_any(
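
To get a feel for why the removed `Data` round-trip was pure overhead, the sketch below contrasts a row-by-row rebuild with a plain copy. It is a rough stand-in only: it does not use Ax's `Data` or `DataRow` classes, it skips the regex sort and cast steps, and the function names and sample frame are hypothetical. It makes no timing claims; the measured costs are the ones quoted in the commit summary.

```python
import pandas as pd


def round_trip(df: pd.DataFrame) -> pd.DataFrame:
    """Rebuild the frame row by row (rough analogue of df -> rows -> df)."""
    # Materialize one Python tuple per row, which scales with the row count.
    rows = list(df.itertuples(index=False))
    # Reassemble a DataFrame from those row records.
    return pd.DataFrame.from_records(rows, columns=list(df.columns))


def direct(df: pd.DataFrame) -> pd.DataFrame:
    """Hand the frame back without leaving pandas."""
    return df.copy()


sample = pd.DataFrame(
    {
        "trial_index": [0, 0, 1],
        "arm_name": ["0_0", "0_0", "1_0"],
        "mean": [1.0, 2.0, 3.0],
    }
)
rebuilt = round_trip(sample)
copied = direct(sample)
print(rebuilt["mean"].tolist() == copied["mean"].tolist())
```

Both paths end with the same values, so the row-by-row path adds work without changing the result, which is the situation the commit describes.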
