Skip to content

Commit 52e30d9

Browse files
ltiaofacebook-github-bot
authored andcommitted
Extract MapDataReplayState for multi-metric experiment replay
Summary: The experiment replay system (`MapDataReplayMetric`, `MapDataReplayRunner`, `replay_experiment`) is hardcoded for single-objective optimization, blocking multi-objective early stopping. `MapDataReplayMetric` conflates data serving with progression state (offset, scaling factor, per-trial step counters), so multiple metrics cannot share a coherent timeline. The scaling factor -- `mean((last_step - offset) / num_points)` -- is an unmotivated heuristic that distorts cross-trial timing. The runner is tightly coupled to a single metric instance. This diff extracts shared state into a `MapDataReplayState` coordinator. The runner owns progression, the metric is a read-only accessor. **Key changes:** - **`MapDataReplayState`** (new): Normalized cursor model. Computes global `min_prog`/`max_prog` across all metrics and trials, advances per-trial cursors by fixed `step_size`, maps to raw progression via `curr_prog = min_prog + cursor * (max_prog - min_prog)`. Serves original MAP_KEY values; downstream ESS normalizes independently. - **`MapDataReplayMetric`** (simplified): Thin wrapper holding a state reference and `metric_signature`. Delegates `fetch_trial_data` to `state.get_data()`. All state-owning attributes removed. - **`MapDataReplayRunner`** (simplified): Takes shared state instead of metric. Calls `advance_trial()` for running trials. - **`replay_experiment`**: Accepts `metrics: list[Metric]`. Builds `OptimizationConfig` or `MultiObjectiveOptimizationConfig` accordingly. Extracts objective thresholds from the historical experiment's config. Re-indexes non-contiguous trial indices to contiguous `0, 1, 2, ...`. Deprecates `num_samples_per_curve` (superseded by `step_size`). - **`estimate_hypothetical_early_stopping_savings`**: Updated for `metrics: list[Metric]`. - **Removed**: `_compute_trial_stats`, `_compute_scaling_factor` (superseded by cursor model). - **Downstream consumers updated**: `ax_sweep_orchestrator`, `early_stopping_healthcheck`, `fblearner/ae/ess_replay/utils` (`metric=` -> `metrics=[...]`). Differential Revision: D96999702
1 parent 43cf130 commit 52e30d9

7 files changed

Lines changed: 893 additions & 260 deletions

File tree

ax/analysis/healthcheck/early_stopping_healthcheck.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -408,7 +408,7 @@ def _report_early_stopping_nudge(
408408
try:
409409
savings = estimate_hypothetical_early_stopping_savings(
410410
experiment=experiment,
411-
metric=metric,
411+
metrics=[metric],
412412
max_pending_trials=self.max_pending_trials,
413413
)
414414
except Exception as e:

ax/early_stopping/experiment_replay.py

Lines changed: 84 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -7,14 +7,20 @@
77
# pyre-strict
88

99
import logging
10+
import warnings
1011
from logging import Logger
1112
from time import perf_counter
1213

1314
from ax.adapter.registry import Generators
15+
from ax.core.data import Data
1416
from ax.core.experiment import Experiment
1517
from ax.core.metric import Metric
16-
from ax.core.objective import Objective
17-
from ax.core.optimization_config import OptimizationConfig
18+
from ax.core.objective import MultiObjective, Objective
19+
from ax.core.optimization_config import (
20+
MultiObjectiveOptimizationConfig,
21+
OptimizationConfig,
22+
)
23+
from ax.core.outcome_constraint import OutcomeConstraint
1824
from ax.core.parameter import ParameterType, RangeParameter
1925
from ax.core.search_space import SearchSpace
2026
from ax.early_stopping.dispatch import get_default_ess_or_none
@@ -25,7 +31,7 @@
2531
GenerationStep,
2632
GenerationStrategy,
2733
)
28-
from ax.metrics.map_replay import MapDataReplayMetric
34+
from ax.metrics.map_replay import MapDataReplayMetric, MapDataReplayState
2935
from ax.orchestration.orchestrator import Orchestrator, OrchestratorOptions
3036
from ax.runners.map_replay import MapDataReplayRunner
3137
from ax.utils.common.logger import get_logger
@@ -43,16 +49,36 @@ def replay_experiment(
4349
historical_experiment: Experiment,
4450
num_samples_per_curve: int,
4551
max_replay_trials: int,
46-
metric: Metric,
52+
metrics: list[Metric],
4753
max_pending_trials: int,
4854
early_stopping_strategy: BaseEarlyStoppingStrategy | None,
4955
logging_level: int = logging.ERROR,
5056
) -> Experiment | None:
51-
"""A utility function for replaying a historical experiment's data
52-
by initializing a Orchestrator that quickly steps through the existing data.
53-
The main purpose of this function is to compute an hypothetical capacity
54-
savings for a given `early_stopping_strategy`.
57+
"""Replay a historical experiment's data through an Orchestrator.
58+
59+
Initializes an Orchestrator that steps through existing data to compute
60+
hypothetical capacity savings for a given ``early_stopping_strategy``.
61+
Supports both single-objective and multi-objective optimization.
62+
63+
Args:
64+
historical_experiment: The experiment whose data to replay.
65+
num_samples_per_curve: Deprecated. Number of samples per curve for
66+
subsampling. Use ``step_size`` on ``MapDataReplayState`` instead.
67+
max_replay_trials: Maximum number of trials to replay.
68+
metrics: List of metrics to replay. For single-objective, provide
69+
one metric. For multi-objective, provide multiple metrics.
70+
max_pending_trials: Maximum number of pending trials for the
71+
replay orchestrator.
72+
early_stopping_strategy: The early stopping strategy to evaluate.
73+
logging_level: Logging level for the orchestrator.
5574
"""
75+
warnings.warn(
76+
"The `num_samples_per_curve` parameter is deprecated and will be "
77+
"removed in a future release. The `step_size` parameter on "
78+
"`MapDataReplayState` controls replay granularity.",
79+
DeprecationWarning,
80+
stacklevel=2,
81+
)
5682
historical_map_data = historical_experiment.lookup_data()
5783
if not historical_map_data.has_step_column:
5884
logger.warning(
@@ -62,16 +88,51 @@ def replay_experiment(
6288
historical_map_data = historical_map_data.subsample(
6389
limit_rows_per_group=num_samples_per_curve, include_first_last=True
6490
)
65-
replay_metric = MapDataReplayMetric(
66-
name=f"replay_{historical_experiment.name}",
67-
map_data=historical_map_data,
68-
metric_name=metric.name,
69-
lower_is_better=metric.lower_is_better,
70-
)
71-
optimization_config = OptimizationConfig(
72-
objective=Objective(metric=replay_metric),
91+
92+
# Re-index non-contiguous trial indices to contiguous 0, 1, 2, ...
93+
# so that replay trial N maps to the Nth historical trial.
94+
df = historical_map_data.full_df
95+
sorted_trial_indices = sorted(df["trial_index"].unique())
96+
trial_index_map = {old: new for new, old in enumerate(sorted_trial_indices)}
97+
df = df.copy()
98+
df["trial_index"] = df["trial_index"].map(trial_index_map)
99+
historical_map_data = Data(df=df)
100+
101+
metric_signatures = [m.signature for m in metrics]
102+
replay_state = MapDataReplayState(
103+
map_data=historical_map_data, metric_signatures=metric_signatures
73104
)
74-
runner = MapDataReplayRunner(replay_metric=replay_metric)
105+
106+
replay_metrics = [
107+
MapDataReplayMetric(
108+
name=m.name,
109+
replay_state=replay_state,
110+
metric_signature=m.signature,
111+
lower_is_better=m.lower_is_better,
112+
)
113+
for m in metrics
114+
]
115+
116+
if len(replay_metrics) == 1:
117+
optimization_config: OptimizationConfig = OptimizationConfig(
118+
objective=Objective(metric=replay_metrics[0]),
119+
)
120+
else:
121+
# Extract objective thresholds from the historical experiment's config
122+
historical_opt_config = historical_experiment.optimization_config
123+
objective_thresholds: list[OutcomeConstraint] = []
124+
if isinstance(historical_opt_config, MultiObjectiveOptimizationConfig):
125+
objective_thresholds = [
126+
ot.clone() for ot in historical_opt_config.objective_thresholds
127+
]
128+
optimization_config = MultiObjectiveOptimizationConfig(
129+
objective=MultiObjective(
130+
objectives=[Objective(metric=m) for m in replay_metrics]
131+
),
132+
objective_thresholds=objective_thresholds,
133+
)
134+
135+
runner = MapDataReplayRunner(replay_state=replay_state)
75136

76137
# Setup a new experiment with a dummy search space
77138
dummy_search_space = SearchSpace(
@@ -89,10 +150,10 @@ def replay_experiment(
89150
optimization_config=optimization_config,
90151
search_space=dummy_search_space,
91152
runner=runner,
92-
metrics=[replay_metric],
153+
metrics=replay_metrics,
93154
)
94155

95-
# Setup a Orchestrator with a dummy gs to replay the historical experiment
156+
# Setup an Orchestrator with a dummy gs to replay the historical experiment
96157
dummy_sobol_gs = GenerationStrategy(
97158
name="sobol",
98159
steps=[
@@ -101,7 +162,7 @@ def replay_experiment(
101162
)
102163
options = OrchestratorOptions(
103164
max_pending_trials=max_pending_trials,
104-
total_trials=min(len(historical_experiment.trials), max_replay_trials),
165+
total_trials=min(len(sorted_trial_indices), max_replay_trials),
105166
seconds_between_polls_backoff_factor=1.0,
106167
min_seconds_before_poll=0.0,
107168
init_seconds_between_polls=0,
@@ -119,7 +180,7 @@ def replay_experiment(
119180

120181
def estimate_hypothetical_early_stopping_savings(
121182
experiment: Experiment,
122-
metric: Metric,
183+
metrics: list[Metric],
123184
max_pending_trials: int = MAX_PENDING_TRIALS,
124185
) -> float:
125186
"""Estimate hypothetical early stopping savings using experiment replay.
@@ -130,7 +191,7 @@ def estimate_hypothetical_early_stopping_savings(
130191
131192
Args:
132193
experiment: The experiment to analyze.
133-
metric: The metric to use for early stopping replay.
194+
metrics: The metrics to use for early stopping replay.
134195
max_pending_trials: Maximum number of pending trials for the replay
135196
orchestrator. Defaults to 5.
136197
@@ -156,7 +217,7 @@ def estimate_hypothetical_early_stopping_savings(
156217
historical_experiment=experiment,
157218
num_samples_per_curve=REPLAY_NUM_POINTS_PER_CURVE,
158219
max_replay_trials=MAX_REPLAY_TRIALS,
159-
metric=metric,
220+
metrics=metrics,
160221
max_pending_trials=max_pending_trials,
161222
early_stopping_strategy=default_ess,
162223
)

0 commit comments

Comments
 (0)